In [1]:
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
from PIL import Image
import numpy as np
from torchvision import models
from torch.utils.data import DataLoader, Dataset
import torch.utils.data as utils
from torchvision import transforms
import torch.nn.functional as F
from torch.optim import Adam
import cv2
from torchmetrics import JaccardIndex

In [2]:
# from torchvision.models import ResNet50_Weights

In [3]:
"""
MIT License

Copyright (c) 2020 Phil Wang
https://github.com/lucidrains/byol-pytorch/

Adjusted to de-couple for data loading, parallel training
"""
# BYOL for SSL training, not using yet.
import copy
import random
from functools import wraps

import torch
from torch import nn
import torch.nn.functional as F

# helper functions


def default(val, def_val):
    """Return ``val`` unless it is ``None``, in which case return ``def_val``."""
    if val is not None:
        return val
    return def_val


def flatten(t):
    """Collapse every dimension after the first: (N, ...) -> (N, prod(...))."""
    batch = t.size(0)
    return t.reshape(batch, -1)


def singleton(cache_key):
    """Decorator factory that caches a method's result on ``self.<cache_key>``.

    The decorated method runs only while ``getattr(self, cache_key)`` is
    ``None``; its return value is then stored in that attribute and returned
    on every later call (later arguments are ignored).
    """
    def decorator(method):
        @wraps(method)
        def cached_method(self, *args, **kwargs):
            cached_value = getattr(self, cache_key)
            if cached_value is None:
                cached_value = method(self, *args, **kwargs)
                setattr(self, cache_key, cached_value)
            return cached_value

        return cached_method

    return decorator


# loss fn


def loss_fn(x, y):
    """BYOL regression loss: ``2 - 2 * cos(x, y)`` along the last dimension.

    Both inputs are L2-normalised over ``dim=-1`` first, so the result is the
    squared Euclidean distance between the unit vectors (range [0, 4]).
    """
    x_unit = F.normalize(x, p=2, dim=-1)
    y_unit = F.normalize(y, p=2, dim=-1)
    cosine = (x_unit * y_unit).sum(dim=-1)
    return 2 - 2 * cosine


# augmentation utils


class RandomApply(nn.Module):
    """Apply ``fn`` to the input with probability ``p``; otherwise pass it through.

    Each call draws one number from Python's ``random`` module.
    """

    def __init__(self, fn, p):
        super().__init__()
        self.fn = fn
        self.p = p

    def forward(self, x):
        skip = random.random() > self.p
        return x if skip else self.fn(x)


# exponential moving average


class EMA:
    """Exponential moving average helper with decay ``beta``."""

    def __init__(self, beta):
        super().__init__()
        self.beta = beta

    def update_average(self, old, new):
        """Return ``old * beta + (1 - beta) * new``; with no previous value (``None``) return ``new``."""
        if old is not None:
            return old * self.beta + (1 - self.beta) * new
        return new


def update_moving_average(ema_updater, ma_model, current_model):
    """Blend ``current_model``'s parameters into ``ma_model``'s, in place.

    Parameters are paired positionally (``zip`` of both ``parameters()``
    iterators), and each averaged tensor is replaced by
    ``ema_updater.update_average(averaged, current)`` through ``.data``, so no
    autograd history is recorded.
    """
    param_pairs = zip(current_model.parameters(), ma_model.parameters())
    for online_param, averaged_param in param_pairs:
        averaged_param.data = ema_updater.update_average(
            averaged_param.data, online_param.data
        )


# MLP class for projector and predictor


class MLP(nn.Module):
    """Linear -> BatchNorm1d -> ReLU -> Linear head (BYOL projector / predictor).

    Args:
        dim: input feature size.
        projection_size: output feature size.
        hidden_size: width of the hidden layer.
    """

    def __init__(self, dim, projection_size, hidden_size=4096):
        super().__init__()
        layers = [
            nn.Linear(dim, hidden_size),
            nn.BatchNorm1d(hidden_size),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_size, projection_size),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        return self.net(x)


# a wrapper class for the base neural network
# will manage the interception of the hidden layer output
# and pipe it into the projecter and predictor nets


class NetWrapper(nn.Module):
    """Wrap a backbone so that ``forward`` returns a projection of one of its layers.

    For any ``layer`` other than ``-1``, a forward hook is registered lazily
    (on the first call) on the chosen layer. It captures that layer's output,
    flattened to ``(batch, -1)``. The representation is then fed through an MLP
    projector that is created on first use (see ``_get_projector``).

    Args:
        net: backbone module.
        projection_size: output size of the projector MLP.
        projection_hidden_size: hidden size of the projector MLP.
        layer: the name of a submodule (as in ``net.named_modules()``) or an
            index into ``net.children()``. ``-1`` means "use the backbone's own
            output directly".
    """

    def __init__(self, net, projection_size, projection_hidden_size, layer=-2):
        super().__init__()
        self.net = net
        self.layer = layer

        self.projector = None  # created lazily by _get_projector
        self.projection_size = projection_size
        self.projection_hidden_size = projection_hidden_size

        self.hidden = None  # last output captured by the forward hook
        self.hook_registered = False

    def _find_layer(self):
        """Return the module selected by ``self.layer``, or ``None`` if no name matches."""
        if isinstance(self.layer, str):
            return dict(self.net.named_modules()).get(self.layer)
        if isinstance(self.layer, int):
            return list(self.net.children())[self.layer]
        return None

    def _hook(self, _, __, output):
        # Forward-hook signature is (module, inputs, output); only the output is kept.
        self.hidden = flatten(output)

    def _register_hook(self):
        """Attach ``_hook`` to the selected layer (asserts the layer exists)."""
        layer = self._find_layer()
        assert layer is not None, f"hidden layer ({self.layer}) not found"
        layer.register_forward_hook(self._hook)
        self.hook_registered = True

    @singleton("projector")
    def _get_projector(self, hidden):
        """Build the projector MLP once, sized from ``hidden``'s feature dimension
        and moved to ``hidden``'s device/dtype."""
        _, dim = hidden.shape
        projector = MLP(dim, self.projection_size, self.projection_hidden_size)
        return projector.to(hidden)

    def get_representation(self, x):
        """Run the backbone and return the (flattened) output of the selected layer."""
        if self.layer == -1:
            # The backbone's final output is the representation, so no hook is
            # needed. Registering one would keep the last output (and its graph)
            # alive in self.hidden between calls.
            return self.net(x)

        if not self.hook_registered:
            self._register_hook()

        _ = self.net(x)
        hidden = self.hidden
        self.hidden = None  # don't keep the activation (and its graph) alive
        assert hidden is not None, f"hidden layer {self.layer} never emitted an output"
        return hidden

    def forward(self, x):
        representation = self.get_representation(x)
        projector = self._get_projector(representation)
        projection = projector(representation)
        return projection


# main class


class BYOL(nn.Module):
    """Bootstrap Your Own Latent (BYOL) self-supervised learner.

    An online encoder (backbone + projector) followed by a predictor is trained
    to match the projection that a target encoder produces for the *other*
    view of the same image. The target encoder is a deep copy of the online
    encoder. Its weights are meant to be moved towards the online encoder via
    ``update_moving_average`` (an EMA with decay ``moving_average_decay``); it
    is only run under ``torch.no_grad`` in ``forward``.

    Args:
        net: backbone network.
        image_size: side length of the square mock images passed through once
            at construction time to instantiate the lazily built projector and
            target encoder.
        hidden_layer: layer of ``net`` whose output is projected (see NetWrapper).
        projection_size: output size of the projector and of the predictor.
        projection_hidden_size: hidden width of the projector / predictor MLPs.
        augment_fn: accepted for API compatibility but not used by this class
            (presumably augmentation happens during data loading).
        moving_average_decay: EMA decay used for the target encoder.
    """

    def __init__(
        self,
        net,
        image_size,
        hidden_layer=-2,
        projection_size=256,
        projection_hidden_size=4096,
        augment_fn=None,
        moving_average_decay=0.99,
    ):
        super().__init__()

        # Build the mock inputs and the predictor on the backbone's device. The
        # original CPU-only mock tensors crashed when `net` already lived on a GPU.
        first_param = next(net.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

        self.online_encoder = NetWrapper(
            net, projection_size, projection_hidden_size, layer=hidden_layer
        )
        self.target_encoder = None  # created lazily by _get_target_encoder
        self.target_ema_updater = EMA(moving_average_decay)

        self.online_predictor = MLP(
            projection_size, projection_size, projection_hidden_size
        ).to(device)

        # send a mock image tensor to instantiate singleton parameters
        self.forward(
            torch.randn(2, 3, image_size, image_size, device=device),
            torch.randn(2, 3, image_size, image_size, device=device),
        )

    @singleton("target_encoder")
    def _get_target_encoder(self):
        """Create the target encoder once, as a deep copy of the online encoder."""
        target_encoder = copy.deepcopy(self.online_encoder)
        return target_encoder

    def reset_moving_average(self):
        """Drop the target encoder; it is re-copied from the online encoder on the next forward."""
        del self.target_encoder
        self.target_encoder = None

    def update_moving_average(self):
        """Move the target encoder's weights towards the online encoder's (EMA)."""
        assert (
            self.target_encoder is not None
        ), "target encoder has not been created yet"
        # Resolves to the module-level function, not this method.
        update_moving_average(
            self.target_ema_updater, self.target_encoder, self.online_encoder
        )

    def forward(self, image_one, image_two):
        """Return the symmetric BYOL loss (mean over the batch) for two views."""
        online_proj_one = self.online_encoder(image_one)
        online_proj_two = self.online_encoder(image_two)

        online_pred_one = self.online_predictor(online_proj_one)
        online_pred_two = self.online_predictor(online_proj_two)

        with torch.no_grad():
            target_encoder = self._get_target_encoder()
            target_proj_one = target_encoder(image_one)
            target_proj_two = target_encoder(image_two)

        # Each view's prediction regresses the target projection of the other view.
        loss_one = loss_fn(online_pred_one, target_proj_two.detach())
        loss_two = loss_fn(online_pred_two, target_proj_one.detach())

        loss = loss_one + loss_two
        return loss.mean()

In [4]:
import torch
import torch.nn as nn
import torchvision

# Unet basic model structures.
class ConvBlock(nn.Module):
    """
    Conv2d -> BatchNorm2d, optionally followed by ReLU.

    Args:
        in_channels / out_channels: channel counts of the convolution.
        padding, kernel_size, stride: forwarded to ``nn.Conv2d`` (defaults keep
            the spatial size: 3x3 kernel, padding 1, stride 1).
        with_nonlinearity: when False the ReLU is skipped.
    """

    def __init__(self, in_channels, out_channels, padding=1, kernel_size=3, stride=1, with_nonlinearity=True):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding)
        self.bn = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU()
        self.with_nonlinearity = with_nonlinearity

    def forward(self, x):
        normalized = self.bn(self.conv(x))
        if not self.with_nonlinearity:
            return normalized
        return self.relu(normalized)


class Bridge(nn.Module):
    """
    Middle layer of the U-Net: two ConvBlocks applied back to back
    (in_channels -> out_channels -> out_channels).
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        blocks = [
            ConvBlock(in_channels, out_channels),
            ConvBlock(out_channels, out_channels),
        ]
        self.bridge = nn.Sequential(*blocks)

    def forward(self, x):
        return self.bridge(x)


class UpBlockForUNetWithResNet50(nn.Module):
    """
    Up block that encapsulates one up-sampling step which consists of Upsample -> ConvBlock -> ConvBlock

    The upsampled ``up_x`` is concatenated with the skip tensor ``down_x`` along
    the channel axis, so ``up_conv_out_channels`` + channels(down_x) must equal
    ``in_channels``.

    Args:
        in_channels: channels entering the first ConvBlock (after concatenation).
        out_channels: channels produced by both ConvBlocks.
        up_conv_in_channels: channels of ``up_x`` (defaults to ``in_channels``).
        up_conv_out_channels: channels after upsampling (defaults to ``out_channels``).
        upsampling_method: ``"conv_transpose"`` (2x2 transposed conv, stride 2) or
            ``"bilinear"`` (bilinear x2 followed by a 1x1 conv).

    Raises:
        ValueError: for an unknown ``upsampling_method``.
    """

    def __init__(self, in_channels, out_channels, up_conv_in_channels=None, up_conv_out_channels=None,
                 upsampling_method="conv_transpose"):
        super().__init__()

        if up_conv_in_channels is None:
            up_conv_in_channels = in_channels
        if up_conv_out_channels is None:
            up_conv_out_channels = out_channels

        if upsampling_method == "conv_transpose":
            self.upsample = nn.ConvTranspose2d(up_conv_in_channels, up_conv_out_channels, kernel_size=2, stride=2)
        elif upsampling_method == "bilinear":
            # The 1x1 conv acts on the upsampled up_x, so it must use the up_conv
            # channel counts (previously in/out_channels, which broke any block
            # whose up_conv_* differ from in/out_channels).
            self.upsample = nn.Sequential(
                nn.Upsample(mode='bilinear', scale_factor=2),
                nn.Conv2d(up_conv_in_channels, up_conv_out_channels, kernel_size=1, stride=1)
            )
        else:
            raise ValueError(
                f"unknown upsampling_method {upsampling_method!r}; "
                "expected 'conv_transpose' or 'bilinear'"
            )
        self.conv_block_1 = ConvBlock(in_channels, out_channels)
        self.conv_block_2 = ConvBlock(out_channels, out_channels)

    def forward(self, up_x, down_x):
        """

        :param up_x: this is the output from the previous up block
        :param down_x: this is the output from the down block
        :return: upsampled feature map
        """
        x = self.upsample(up_x)
        if x.shape[-2:] != down_x.shape[-2:]:
            # Odd sizes in the encoder (input H/W not a multiple of 32) make the
            # 2x upsampled map differ from the skip by a pixel; resize to match
            # instead of failing in torch.cat.
            x = F.interpolate(x, size=down_x.shape[-2:], mode='bilinear', align_corners=False)
        x = torch.cat([x, down_x], 1)
        x = self.conv_block_1(x)
        x = self.conv_block_2(x)
        return x


class UNetWithResnet50Encoder(nn.Module):
    """U-Net whose contracting path is a torchvision ResNet-50 (built with ``weights=None``).

    Skip connections (keys ``layer_0`` .. ``layer_4``):
        layer_0: the raw input
        layer_1: stem output (conv1 -> bn1 -> relu)
        layer_2..layer_4: outputs of ResNet ``layer1``..``layer3``
    The output of ResNet ``layer4`` feeds the bridge. Five up blocks then
    restore the input resolution, and a 1x1 conv maps to ``n_classes`` channels.

    NOTE(review): the skip concatenations line up when the input height and
    width are multiples of 32; other sizes may fail unless the up blocks
    resize to the skip size.
    """

    DEPTH = 6

    def __init__(self, n_classes=49):
        super().__init__()
        # Randomly initialised backbone. Earlier experiments swapped in a
        # BYOL-pretrained or a saved ResNet-50 encoder (optionally frozen) here.
        backbone = models.resnet50(weights=None)

        # Stem conv1 -> bn1 -> relu, as a Sequential with keys 0/1/2, then the max-pool.
        self.input_block = nn.Sequential(backbone.conv1, backbone.bn1, backbone.relu)
        self.input_pool = backbone.maxpool
        # The four residual stages of ResNet-50.
        self.down_blocks = nn.ModuleList(
            [backbone.layer1, backbone.layer2, backbone.layer3, backbone.layer4]
        )
        self.bridge = Bridge(2048, 2048)

        self.up_blocks = nn.ModuleList([
            UpBlockForUNetWithResNet50(2048, 1024),
            UpBlockForUNetWithResNet50(1024, 512),
            UpBlockForUNetWithResNet50(512, 256),
            UpBlockForUNetWithResNet50(in_channels=128 + 64, out_channels=128,
                                       up_conv_in_channels=256, up_conv_out_channels=128),
            UpBlockForUNetWithResNet50(in_channels=64 + 3, out_channels=64,
                                       up_conv_in_channels=128, up_conv_out_channels=64),
        ])

        self.out = nn.Conv2d(64, n_classes, kernel_size=1, stride=1)

    def forward(self, x, with_output_feature_map=False):
        """Return per-pixel class logits; with ``with_output_feature_map`` also
        return the 64-channel feature map that feeds the final 1x1 conv."""
        depth = UNetWithResnet50Encoder.DEPTH
        skips = {"layer_0": x}
        x = self.input_block(x)
        skips["layer_1"] = x
        x = self.input_pool(x)

        for stage, block in enumerate(self.down_blocks, start=2):
            x = block(x)
            # The deepest stage goes straight into the bridge, not into a skip.
            if stage != depth - 1:
                skips[f"layer_{stage}"] = x

        x = self.bridge(x)

        for step, block in enumerate(self.up_blocks, start=1):
            x = block(x, skips[f"layer_{depth - 1 - step}"])

        feature_map = x
        logits = self.out(feature_map)
        del skips
        if with_output_feature_map:
            return logits, feature_map
        return logits

In [5]:
import torch.utils.data as data_utils
import os
from data import UnlabeledDataset,LabeledDataset,ValidationDataset
# Load the data with Daniel's data.py

dataset = LabeledDataset('/scratch/py2050/Dataset_Student/')
train_dataloader = torch.utils.data.DataLoader(dataset, batch_size=2, shuffle=True, num_workers=1)

val_dataset = ValidationDataset('/scratch/py2050/Dataset_Student/')
# No shuffling: the training loop validates on only the first batches of this
# loader. A fixed order scores every epoch on the same frames, so the
# best-model comparison is like-for-like instead of noise from a random subset.
val_dataloader = torch.utils.data.DataLoader(val_dataset, batch_size=2, shuffle=False, num_workers=1)

In [6]:
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
model = UNetWithResnet50Encoder(n_classes=49).to(device)
# To resume training, load a saved model instead (remember to modify the path):
# model = torch.load('best_model_unet_50.pth')
criterion = nn.CrossEntropyLoss()
# Optimizer hyper-parameters taken from the YOLOv8 default values.
optim = Adam(model.parameters(), lr=0.0003, weight_decay=0.0001)

In [7]:
# Training configuration
model_path = '/scratch/py2050/best_model_unet_50.pth'  # where the best model is saved; modify per user
num_epochs = 40  # number of epochs; adjust to the node's time limit
best_val_acc = 0  # best validation IoU seen so far, used for checkpoint selection
best_train_acc = 1  # NOTE(review): not read anywhere in this notebook
jaccard = JaccardIndex(task="multiclass", num_classes=49).to(device)  # IoU metric over 49 classes

In [None]:
from tqdm import tqdm

max_val_batches = 50  # validate on the first 50 batches only, to shorten val time
output_size = (160, 240)  # logits are upsampled to the label resolution

for epoch in range(1, num_epochs + 1):
    print("Training epoch: ", epoch)
    train_loss = 0
    model.train()

    for data in tqdm(train_dataloader):
        frames, labels = data
        frames, labels = frames.to(device), labels.to(device)
        # Fold the per-video frame axis into the batch axis:
        # frames (B, T, C, H, W) -> (B*T, C, H, W); labels (B, T, H, W) -> (B*T, H, W).
        frames = frames.reshape(-1, frames.shape[2], frames.shape[3], frames.shape[4])
        labels = labels.reshape(-1, labels.shape[2], labels.shape[3])
        outputs = model(frames)
        outputs = F.interpolate(outputs, size=output_size, mode='bilinear', align_corners=False)
        loss = criterion(outputs, labels.long())
        loss.backward()
        optim.step()
        optim.zero_grad()
        train_loss += loss.item()
    # Each loss.item() is already a mean over the batch, so average over batches
    # (dividing by len(dataset) * 22 under-reported the loss).
    train_loss /= max(len(train_dataloader), 1)

    print("Validating epoch: ", epoch)
    # Eval mode: BatchNorm must use (and not update) its running statistics.
    model.eval()
    val_IoU_accuracy = 0
    num_val_frames = 0
    with torch.no_grad():  # no autograd graph is needed for validation
        val_progress = tqdm(val_dataloader, total=min(len(val_dataloader), max_val_batches))
        for idx, data in enumerate(val_progress):
            val_input, val_label = data
            frames, labels = val_input.to(device), val_label.to(device)
            frames = frames.reshape(-1, frames.shape[2], frames.shape[3], frames.shape[4])
            labels = labels.reshape(-1, labels.shape[2], labels.shape[3])
            outputs = model(frames)
            outputs = F.interpolate(outputs, size=output_size, mode='bilinear', align_corners=False)
            # argmax over the class axis; log-softmax is monotonic, so it is not needed.
            output = torch.argmax(outputs, dim=1)
            for i in range(labels.shape[0]):
                val_IoU_accuracy += jaccard(output[i], labels[i])
            num_val_frames += labels.shape[0]
            if idx + 1 >= max_val_batches:
                break
    # Mean per-frame IoU over the frames actually evaluated (no hard-coded count).
    val_IoU_accuracy /= max(num_val_frames, 1)

    print("Epoch{}: Training Loss:{:.6f}; Val IoU: {:.6f}.\n".format(epoch, train_loss, val_IoU_accuracy))
    if best_val_acc < val_IoU_accuracy:
        best_val_acc = val_IoU_accuracy
        torch.save(model, model_path)
        print('Best model saved')

Training epoch:  1


 44%|████▍     | 221/500 [10:48<13:15,  2.85s/it]

In [None]:
# Reload the best checkpoint written by the training loop. It is a whole pickled
# model, not a state_dict. Only load checkpoints you produced yourself:
# unpickling can execute arbitrary code.
try:
    model = torch.load(model_path, map_location=device, weights_only=False)
except TypeError:  # torch versions that predate the `weights_only` argument
    model = torch.load(model_path, map_location=device)
model.eval()  # inference: BatchNorm must use its running statistics

num_check = 50  # video index to display, choose from 0~999
frame_check = 1  # frame index to display, choose from 0~21


In [None]:
# Predict one frame and post-process it the same way as in validation:
# upsample the logits to 160x240 (bilinear), then take the per-pixel argmax.
input = dataset[num_check][0][frame_check].unsqueeze(0)
input = input.to(device)

model.eval()  # a single-frame batch must use BatchNorm running statistics
with torch.no_grad():  # inference only, no autograd graph
    output = model(input)
    output = F.interpolate(output, size=(160, 240), mode='bilinear', align_corners=False)
    outputs = torch.argmax(output, dim=1)
print(outputs.shape)

outputs = outputs.squeeze(0).cpu().numpy()
print(outputs.shape)
plt.imshow(outputs)
plt.axis('off')
plt.title('Predicted unresized Image')

In [None]:
# Show the (normalised) input frame, resized with nearest-neighbour to 160x240.
# NOTE(review): normalised values outside [0, 1] are clipped by imshow.
video_frames = dataset[num_check][0]
input = transforms.Resize(
    (160, 240), interpolation=transforms.InterpolationMode.NEAREST
)(video_frames[frame_check])

plt.imshow(input.permute(1, 2, 0))  # CHW -> HWC for matplotlib
plt.axis('off')
plt.title('Actual unresized Normalized Image')

In [None]:
# Ground-truth mask for the same video/frame as the prediction above.
video_labels = dataset[num_check][1]
input_truth_label = video_labels[frame_check]

plt.imshow(input_truth_label)
plt.axis('off')
plt.title('Ground Truth Label Image')

In [None]:
# IoU between the predicted and ground-truth masks of the displayed frame.
# Multiclass Jaccard expects integer class maps; torch.Tensor(...) would yield floats.
pred_mask = torch.as_tensor(outputs, device=device).long()
true_mask = torch.as_tensor(input_truth_label, device=device).long()
jac = jaccard(pred_mask, true_mask)
jac