In [None]:
# Make the parent directory importable so that the emtl package can be found
import os
import sys

project_dir = os.path.abspath('..')

# When the kernel was not restarted the entry may already be registered,
# so only append it if it is missing (keeps re-running this cell idempotent)
already_registered = sys.path.count(project_dir) > 0
if not already_registered:
    sys.path.append(project_dir)

In [1]:
import torch
from torchvision import models as M
from torchvision import datasets as D
from torchvision import transforms as T

# EMTL Library Imports
from emtl import Task, Trainer
from emtl.algorithms import SequentialTraining

# Prefer the GPU, but fall back to the CPU so the notebook still runs on machines without CUDA
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [None]:
class DeepLabHeadPipeline(torch.nn.Module):
    """Segmentation head: a torchvision ``DeepLabHead`` followed by bilinear upsampling.

    Args:
        in_features: Number of channels of the feature map fed to the head.
        out_features: Number of classes, i.e. channels of the returned tensor.
        device: Device the head is moved to on construction.
    """

    def __init__(self, in_features: int, out_features: int, device: str = 'cpu') -> None:
        super().__init__()
        self.head = M.segmentation.deeplabv3.DeepLabHead(
            in_channels=in_features, num_classes=out_features).to(device)

    def forward(self, x: torch.Tensor, original_shape: tuple[int, int]) -> torch.Tensor:
        """Apply the head to ``x`` and resize the result to ``original_shape``.

        Args:
            x: Feature map passed to the head.
            original_shape: Spatial size the output is interpolated to.

        Returns:
            The head output, bilinearly resized so that its spatial dimensions
            equal ``original_shape``. This is a plain tensor, not a dict.
        """
        logits = self.head(x)
        # align_corners=False is spelled out to keep the interpolation behaviour explicit
        return torch.nn.functional.interpolate(
            logits, size=original_shape, mode="bilinear", align_corners=False)

In [None]:
# Build a ResNet50 backbone and drop its last two children (avgpool and fc),
# keeping only the convolutional feature extractor
resnet = M.resnet50(replace_stride_with_dilation=[False, True, True])
backbone = torch.nn.Sequential(*list(resnet.children())[:-2]).to(device)

# 2048 input channels for the head (assumed to match the backbone's final feature map — confirm);
# 21 output classes (presumably the 20 VOC object classes plus background)
head = DeepLabHeadPipeline(in_features=2048, out_features=21, device=device)

# Every image and mask is resized to the same fixed spatial size
shape = (520,520)
tinput  = T.Compose([T.Resize(shape), T.ToTensor()])

# Masks hold class ids, so nearest-neighbour is requested explicitly: any other
# interpolation could blend neighbouring ids into labels that do not exist.
# The pipeline is built once here instead of once per sample.
_mask_pipeline = T.Compose([
    T.Resize(shape, interpolation=T.InterpolationMode.NEAREST),
    T.PILToTensor(),
])


def ttarget(x):
    """Resize a PIL segmentation mask and return it as a long tensor with dim 0 squeezed."""
    return _mask_pipeline(x).squeeze(0).long()


testset = D.VOCSegmentation(root='data', image_set='val', transform=tinput, target_transform=ttarget)

In [None]:
# Train on the 'train' split instead of reusing the validation data, so that the
# test metrics are measured on images the model was not trained on.
# The same input/target transforms as for the test set are applied.
trainset = D.VOCSegmentation(root='data', image_set='train', transform=tinput, target_transform=ttarget)

VOC_segmentation_task = Task(
    name = 'voc_seg',
    head = head,
    trainset = trainset,
    testset = testset,
    dataloader_params = {'batch_size': 4, 'num_workers': 4, 'pin_memory': True, 'drop_last': True},
    # pixels labelled 255 are excluded from the loss
    criterion = torch.nn.CrossEntropyLoss(ignore_index=255),
    # the optimizer and scheduler are passed as classes, not as instances
    optimizer_fn = torch.optim.Adam,
    scheduler_fn = torch.optim.lr_scheduler.ReduceLROnPlateau
)

In [None]:
# Single-task setup: the shared backbone, the VOC segmentation task and a
# SequentialTraining algorithm configured with 5 epochs; 'config.ini' is
# handed to the trainer as its configuration
algorithm = SequentialTraining(epochs=5)
trainer = Trainer(backbone=backbone, tasks=[VOC_segmentation_task], algorithm=algorithm, config='config.ini')

# start the training run
trainer.launch()