In [1]:
from datasets import load_dataset, Image
from torch.utils.data import DataLoader
from mit_semseg.models.resnet import resnet50
from mit_semseg.models import ModelBuilder, SegmentationModule 
import PIL
import torch
import torch.nn as nn

In [2]:
def prep_dataset(sample, size=(256, 256)):
    """Turn one dataset row into a fixed-size RGB image plus a same-size label mask.

    Args:
        sample: Mapping holding PIL-style image objects under the keys
            'image' and 'annotation'.
        size: Target ``(width, height)`` handed to ``resize`` for both the
            image and the annotation. Defaults to 256x256.

    Returns:
        The same ``sample`` mapping, updated in place:
        'width' / 'height' hold the image's dimensions *before* resizing, and
        'image' / 'annotation' hold the resized images passed through
        ``Image().encode_example``.
    """
    img = sample['image']
    label = sample['annotation']
    # Captured before resizing so the original dimensions stay available.
    sample['width'] = img.width
    sample['height'] = img.height

    # Bring every non-RGB input (greyscale, CMYK, RGBA, palette, ...) to
    # 3-channel RGB so all rows end up with the same channel layout.
    if img.mode != 'RGB':
        img = img.convert('RGB')

    img = img.resize(size, resample=PIL.Image.LANCZOS)
    # The annotation stores class ids per pixel: any interpolating filter
    # would blend neighbouring ids into values that belong to neither class,
    # so it must be resampled with nearest-neighbour.
    label = label.resize(size, resample=PIL.Image.NEAREST)

    sample['image'] = Image().encode_example(img)
    sample['annotation'] = Image().encode_example(label)
    return sample

In [12]:
# Data: the first 1024 rows of the scene_parse_150 train split, each passed
# through prep_dataset, then exposed as torch tensors.
dataset = load_dataset("scene_parse_150", split='train[0:1024]').map(prep_dataset)
dataset.set_format("torch")

# Model pieces from mit_semseg. weights='' is passed to both builders --
# presumably this means "no checkpoint to load"; confirm against ModelBuilder.
encoder = ModelBuilder().build_encoder(arch='resnet50dilated', fc_dim=2048, weights='')
decoder = ModelBuilder().build_decoder(
    arch='ppm', fc_dim=2048, num_class=150, weights='', use_softmax=False
)

# NLLLoss consumes log-probabilities; NOTE(review): verify that the decoder
# emits them when built with use_softmax=False.
crit = nn.NLLLoss(ignore_index=-1)
sm = SegmentationModule(encoder, decoder, crit)

# Shuffled mini-batches of 64 rows.
dataloader = DataLoader(dataset, shuffle=True, batch_size=64)

In [4]:
# Take only the first batch from the loader (the loop exits after one pass).
for i, batch in enumerate(dataloader):
    # Move the last axis to position 1 and cast to float.
    # Assumes batch['image'] arrives as (batch, height, width, channels) -- TODO confirm.
    x = batch['image'].permute(0, 3, 1, 2).to(torch.float)
    break

In [9]:
# Run only the encoder of the segmentation module on the batch prepared above.
y = sm.encoder(x, return_feature_maps=False)
# The result is indexed, so it is a sequence; display the shape of its first entry.
y[0].shape

torch.Size([64, 2048, 32, 32])

In [13]:
# Feed the encoder output to the decoder built in the setup cell.
pred = decoder(y)

In [15]:
# Display the shape of the decoder output.
pred.shape

torch.Size([64, 150, 32, 32])