In [1]:
import os
import random

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torchvision.transforms import ToTensor
import numpy as np
from PIL import Image


def is_image_valid(image_path):
    try:
        img = Image.open(image_path)
        img.verify()
        return True
    except:
        return False

class VideoDataset(Dataset):
    def __init__(self, directory, labeled=True, transform=None):
        self.directory = directory
        self.video_dirs = sorted([d for d in os.listdir(directory) if os.path.isdir(os.path.join(directory, d))])
        self.labeled = labeled
        self.transform = transform

    def __len__(self):
        return len(self.video_dirs)

    def __getitem__(self, idx):
        video_dir = os.path.join(self.directory, self.video_dirs[idx])
        frames = []
        try:
            for i in range(22):
                image_path = os.path.join(video_dir, f"image_{i}.png")
                image = Image.open(image_path).convert("RGB")
                frames.append(ToTensor()(image))
            video = torch.stack(frames)
        except Exception as e:
            print(video_dir)
            print(e)
            
            return None
        # for i in range(22):
        #     image_path = os.path.join(video_dir, f"image_{i}.png")
        #     if is_image_valid(image_path):
        #       image = Image.open(image_path).convert("RGB")
        #       frames.append(ToTensor()(image))
        
        
        if self.labeled:
            mask_path = os.path.join(video_dir, "mask.npy")
            mask_data = np.load(mask_path)
            mask = torch.tensor(mask_data.reshape(22, 160, 240), dtype=torch.long)
            return video, mask
        else:
            return video



path_to_dataset = "/dataset/"
train_path = os.path.join(path_to_dataset, "train")
unlabeled_path = os.path.join(path_to_dataset, "unlabeled")
val_path = os.path.join(path_to_dataset, "val")

train_dataset = VideoDataset(train_path, transform=transforms.ToTensor())
unlabeled_dataset = VideoDataset(unlabeled_path, labeled=False, transform=transforms.ToTensor())
val_dataset = VideoDataset(val_path, transform=transforms.ToTensor())

train_dataloader = DataLoader(train_dataset, batch_size=8, shuffle=True, num_workers=4)
unlabeled_dataloader = DataLoader(unlabeled_dataset, batch_size=8, shuffle=True, num_workers=4)
val_dataloader = DataLoader(val_dataset, batch_size=8, shuffle=False, num_workers=4)


In [2]:
class MoCo(nn.Module):
    def __init__(self, encoder, queue_size=8192, temperature=0.07):
        super(MoCo, self).__init__()
        self.encoder_k = encoder
        self.encoder_q = encoder
        self.queue_size = queue_size
        self.temperature = temperature

        self.register_buffer("queue", torch.randn(128, queue_size))
        self.queue = nn.functional.normalize(self.queue, dim=0)

        self.register_buffer("queue_ptr", torch.zeros(1, dtype=torch.long))

    @torch.no_grad()
    def _momentum_update_key_encoder(self, m=0.999):
        for param_q, param_k in zip(self.encoder_q.parameters(), self.encoder_k.parameters()):
            param_k.data = param_k.data * m + param_q.data * (1. - m)

    @torch.no_grad()
    def _dequeue_and_enqueue(self, keys):
        batch_size = keys.shape[0]

        ptr = int(self.queue_ptr)
        assert self.queue_size % batch_size == 0, "queue size should be divisible by batch size"

        self.queue[:, ptr:ptr + batch_size] = keys.t()
        ptr = (ptr + batch_size) % self.queue_size

        self.queue_ptr[0] = ptr

    def forward(self, x_q, x_k):
        _, _, q = self.encoder_q(x_q) # Change this line to get the proper tensor output
        _, _, k = self.encoder_k(x_k) # Change this line to get the proper tensor output

        k = nn.functional.normalize(k, dim=1)
        self._dequeue_and_enqueue(k)

        q = self.encoder_q(x_q)
        q = nn.functional.normalize(q, dim=1)

        l_pos = torch.einsum("nc,nc->n", [q, k]).unsqueeze(-1)
        l_neg = torch.matmul(q, self.queue.clone().detach())

        logits = torch.cat([l_pos, l_neg], dim=1)
        logits /= self.temperature

        labels = torch.zeros(logits.shape[0], dtype=torch.long).cuda()

        return logits, labels


In [11]:
class VideoEncoder(nn.Module):
    def __init__(self, in_channels):
        super(VideoEncoder, self).__init__()
        self.conv1 = nn.Conv3d(in_channels, 64, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
        self.relu1 = nn.ReLU()
        self.conv2 = nn.Conv3d(64, 128, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
        self.relu2 = nn.ReLU()
        self.pool = nn.MaxPool3d(kernel_size=(2, 2, 2), stride=(2, 2, 2), padding=(1, 1, 1))
        self.conv3 = nn.Conv3d(128, 256, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
        self.relu3 = nn.ReLU()


    def forward(self, x):
        x1 = self.relu1(self.conv1(x))
        x2 = self.pool(self.relu2(self.conv2(x1)))
        x3 = self.pool(self.relu3(self.conv3(x2)))
        return x1, x2, x3

In [8]:
from torchvision.transforms import RandomHorizontalFlip, Compose

def data_augmentation():
    return Compose([
        RandomHorizontalFlip(p=0.5),
        
    ])

from torchvision.transforms.functional import to_pil_image

def augment_video(batch_video, augmentation):
    augmented_batch_video = []
    for video in batch_video:
        augmented_video = []
        for frame in video:
            frame = to_pil_image(frame)  # Convert tensor to PIL Image
            augmented_frame = augmentation(frame)
            augmented_frame = ToTensor()(augmented_frame)  # Change back to (C, H, W)
            augmented_video.append(augmented_frame)
        augmented_batch_video.append(torch.stack(augmented_video))
    return torch.stack(augmented_batch_video)

In [12]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

encoder = VideoEncoder(in_channels=22).to(device)
moco = MoCo(encoder).to(device)
optimizer = optim.Adam(moco.parameters(), lr=0.001)

def train_moco(dataloader, model, optimizer, epochs=10):
    model.train()
    augmentation = data_augmentation()
    for epoch in range(epochs):
        for idx, video in enumerate(dataloader):
            data_q = video.to(device)
            data_k = augment_video(video, augmentation).to(device)

            logits, labels = moco(data_q, data_k)

            loss = nn.CrossEntropyLoss()(logits, labels)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            model._momentum_update_key_encoder()

            if idx % 10 == 0:
                print(f"Epoch: {epoch}, Step: {idx}, Loss: {loss.item()}")

random.seed(0)
train_indices = list(range(len(unlabeled_dataset)))
random.shuffle(train_indices)
train_sampler = torch.utils.data.SubsetRandomSampler(train_indices)

unlabeled_train_dataloader = DataLoader(unlabeled_dataset, batch_size=8, sampler=train_sampler, num_workers=1)

train_moco(unlabeled_train_dataloader, moco, optimizer)


AttributeError: 'tuple' object has no attribute 'norm'

In [5]:
class UNetDecoder(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(UNetDecoder, self).__init__()
        self.up = nn.ConvTranspose3d(in_channels, in_channels // 2, kernel_size=2, stride=2)
        self.conv = nn.Sequential(
            nn.Conv3d(in_channels, in_channels // 2, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv3d(in_channels // 2, out_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True)
        )

    def forward(self, x1, x2):
        x1 = self.up(x1)
        x = torch.cat([x2, x1], dim=1)
        x = self.conv(x)
        return x

class SegmentationModel(nn.Module):
    def __init__(self, encoder, num_classes):
        super(SegmentationModel, self).__init__()
        self.encoder = encoder
        
        self.middle = nn.Sequential(
            nn.Conv3d(512, 1024, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv3d(1024, 512, kernel_size=3, padding=1),
            nn.ReLU(inplace=True)
        )

        self.up1 = UNetDecoder(1024, 256)
        self.up2 = UNetDecoder(512, 128)
        self.up3 = UNetDecoder(256, 64)
        self.up4 = UNetDecoder(128, num_classes)
        
    def forward(self, x):
        x1, x2, x3, x4, x5 = self.encoder(x)
        x = self.middle(x5)
        x = self.up1(x, x4)
        x = self.up2(x, x3)
        x = self.up3(x, x2)
        x = self.up4(x, x1)
        return x


In [6]:
seg_model = SegmentationModel(moco.encoder_q, num_classes=24).to(device)
seg_optimizer = optim.Adam(seg_model.parameters(), lr=0.001)

def train_segmentation(dataloader, model, optimizer, epochs=10):
    model.train()
    for epoch in range(epochs):
        for idx, (data, labels) in enumerate(dataloader):
            data = data.to(device)
            labels = labels.to(device)

            logits = model(data)

            loss = nn.CrossEntropyLoss()(logits, labels)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if idx % 10 == 0:
                print(f"Epoch: {epoch}, Step: {idx}, Loss: {loss.item()}")

train_segmentation(train_dataloader, seg_model, seg_optimizer)


NameError: name 'moco' is not defined

In [None]:
import torch.nn.functional as F

class UNet3D(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(UNet3D, self).__init__()

        self.encoder = nn.Sequential(
            nn.Conv3d(in_channels, 64, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv3d(64, 64, kernel_size=3, padding=1),
            nn.ReLU(inplace=True)
        )

        self.middle = nn.Sequential(
            nn.Conv3d(64, 128, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv3d(128, 128, kernel_size=3, padding=1),
            nn.ReLU(inplace=True)
        )

        self.decoder = nn.Sequential(
            nn.Conv3d(128, 64, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv3d(64, out_channels, kernel_size=3, padding=1)
        )

        self.down = nn.MaxPool3d(2, stride=2)
        self.up = nn.ConvTranspose3d(128, 64, kernel_size=2, stride=2)

    def forward(self, x):
        x1 = self.encoder(x)
        x2 = self.middle(self.down(x1))
        x3 = self.decoder(self.up(x2))
        return x3


In [None]:
import numpy as np
from sklearn.cluster import KMeans

def mc_jepa_loss(logits, pseudo_labels, alpha=0.1):
    pred = F.softmax(logits, dim=1)
    pred = pred.view(pred.size(0), pred.size(1), -1)
    cluster_centers = torch.zeros((pred.size(1), pred.size(2)), device=logits.device)
    
    for i in range(pred.size(1)):
        cluster_centers[i] = torch.mean(pred[pseudo_labels == i], dim=0)
    
    joint_entropy = -torch.mean(torch.sum(pred * torch.log(torch.clamp(cluster_centers[pseudo_labels], 1e-10)), dim=1))
    cross_entropy = -torch.mean(torch.sum(pred * torch.log(torch.clamp(pred, 1e-10)), dim=1))

    loss = (1 - alpha) * cross_entropy + alpha * joint_entropy
    return loss

def mc_jepa_pseudo_labeling(logits):
    pred = F.softmax(logits, dim=1)
    pred_np = pred.cpu().numpy().reshape(pred.size(0), pred.size(1), -1)
    kmeans = KMeans(n_clusters=pred.size(1), random_state=0)
    pseudo_labels = []

    for i in range(pred_np.shape[0]):
        kmeans.fit(pred_np[i].T)
        pseudo_labels.append(kmeans.labels_)

    pseudo_labels = np.array(pseudo_labels).reshape(pred.size(0), *pred.size()[2:])
    return torch.tensor(pseudo_labels, dtype=torch.long, device=logits.device)


In [None]:
unet3d = UNet3D(1, num_classes).to(device)
optimizer = optim.Adam(unet3d.parameters(), lr=0.001)

def train_mc_jepa(dataloader, model, optimizer, alpha=0.1, epochs=10):
    model.train()
    for epoch in range(epochs):
        for idx, data in enumerate(dataloader):
            data = data.to(device)
            optimizer.zero_grad()
            logits = model(data)
            
            pseudo_labels = mc_jepa_pseudo_labeling(logits)
            loss = mc_jepa_loss(logits, pseudo_labels, alpha)
            
            loss.backward()
            optimizer.step()

            if idx % 10 == 0:
                print(f"Epoch: {epoch}, Step: {idx}, Loss: {loss.item()}")

train_mc_jepa(unlabeled_train_dataloader, unet3d, optimizer)


In [None]:
def train_finetune(dataloader, model, optimizer, epochs=10):
    model.train()
    for epoch in range(epochs):
        for idx, (data, labels) in enumerate(dataloader):
            data = data.to(device)
            labels = labels.to(device)

            logits = model(data)
            loss = nn.CrossEntropyLoss()(logits, labels)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if idx % 10 == 0:
                print(f"Epoch: {epoch}, Step: {idx}, Loss: {loss.item()}")

train_finetune(train_dataloader, unet3d, optimizer)
