In [5]:
import os
import glob
import json
import random
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import json
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import albumentations as A
from albumentations.pytorch import ToTensorV2

import torchvision.models as models
import torchvision.transforms as transforms
import torchvision.transforms.functional as TF
from torchvision.transforms.functional import to_pil_image
from torchvision.datasets import ImageFolder
from PIL import Image

# from google.colab import drive

from collections import Counter

In [6]:

# --- Double Convolution Block ---
class DoubleConv(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.double_conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),

            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        return self.double_conv(x)

# --- Downsampling Block (Encoder) ---
class Down(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.down = nn.Sequential(
            nn.MaxPool2d(2),
            DoubleConv(in_channels, out_channels)
        )

    def forward(self, x):
        return self.down(x)

# --- Upsampling Block (Decoder) ---
class Up(nn.Module):
    def __init__(self, in_channels, out_channels, bilinear=True):
        super().__init__()

        if bilinear:
            self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
            self.conv = DoubleConv(in_channels, out_channels)
        else:
            self.up = nn.ConvTranspose2d(in_channels // 2, in_channels // 2, kernel_size=2, stride=2)
            self.conv = DoubleConv(in_channels, out_channels)

    def forward(self, x1, x2):
        x1 = self.up(x1)

        # pad if necessary
        diffY = x2.size()[2] - x1.size()[2]
        diffX = x2.size()[3] - x1.size()[3]
        x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2,
                        diffY // 2, diffY - diffY // 2])

        x = torch.cat([x2, x1], dim=1)
        return self.conv(x)

# --- Output Layer ---
class OutConv(nn.Module):
    def __init__(self, in_channels, num_classes):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, num_classes, kernel_size=1)

    def forward(self, x):
        return self.conv(x)

# --- Final UNet Model ---
class UNet(nn.Module):
    def __init__(self, n_channels=3, n_classes=50, bilinear=True):
        super(UNet, self).__init__()
        self.n_channels = n_channels
        self.n_classes = n_classes
        self.bilinear = bilinear

        self.inc = DoubleConv(n_channels, 64)
        self.down1 = Down(64, 128)
        self.down2 = Down(128, 256)
        self.down3 = Down(256, 512)
        factor = 2 if bilinear else 1
        self.down4 = Down(512, 1024 // factor)

        self.up1 = Up(1024, 512 // factor, bilinear)
        self.up2 = Up(512, 256 // factor, bilinear)
        self.up3 = Up(256, 128 // factor, bilinear)
        self.up4 = Up(128, 64, bilinear)

        self.outc = OutConv(64, n_classes)

    def forward(self, x):
        x1 = self.inc(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        x5 = self.down4(x4)

        x = self.up1(x5, x4)
        x = self.up2(x, x3)
        x = self.up3(x, x2)
        x = self.up4(x, x1)
        logits = self.outc(x)
        return logits


In [7]:
model = UNet(n_channels=3, n_classes=50)
output = model(torch.randn(1, 3, 256, 256))  # output shape: [1, 50, 256, 256]

In [None]:
import os
import glob
import json
import numpy as np
from PIL import Image
from torch.utils.data import Dataset, DataLoader
import torch
import albumentations as A
from albumentations.pytorch import ToTensorV2

class ImageNetSubsetSegmentationDataset(Dataset):
    def __init__(self, json_path, image_size=(224, 224), mode='both'):
        """
        mode: 'classification' | 'segmentation' | 'both'
        """
        self.image_dir = "train-semi"
        self.mask_dir = "train-semi-segmentation"
        self.mode = mode
        self.image_size = image_size

        with open(json_path, "r") as f:
            self.samples = json.load(f)

        self.aug = A.Compose([
            A.RandomHorizontalFlip(p=0.3),
            A.Rotate(limit=5, p=0.3),
            A.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.05, p=0.3),
            A.Resize(image_size[0], image_size[1], interpolation=1),  # bilinear for image
            A.Normalize(mean=[0.485, 0.456, 0.406],
                        std=[0.229, 0.224, 0.225]),
            ToTensorV2()
        ])

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        entry = self.samples[idx]
        image_rel = entry["image"]
        mask_rel = entry["mask"]
        class_id = entry["class_id"]

        # Resolve paths
        image_path = os.path.join(self.image_dir, image_rel)
        if not os.path.exists(image_path):
            pattern = os.path.splitext(image_path)[0] + ".*"
            matches = glob.glob(pattern)
            if matches:
                image_path = matches[0]
            else:
                raise FileNotFoundError(f"Image not found: {image_rel}")

        mask_path = os.path.join(self.mask_dir, mask_rel)
        if not os.path.exists(mask_path):
            pattern = os.path.splitext(mask_path)[0] + ".*"
            matches = glob.glob(pattern)
            if matches:
                mask_path = matches[0]
            else:
                raise FileNotFoundError(f"Mask not found: {mask_rel}")

        # Load image and mask
        image = np.array(Image.open(image_path).convert("RGB"))
        mask = np.array(Image.open(mask_path).convert("RGB"))
        if mask.ndim == 3 and mask.shape[2] == 3:
            mask = mask[:, :, 0]  # extract single channel

        # Apply transforms
        augmented = self.aug(image=image, mask=mask)
        image = augmented["image"]
        mask = augmented["mask"].long()  # final shape: [H, W]

        if self.mode == "classification":
            return image, class_id

        if self.mode == "segmentation":
            return image, mask

        if self.mode == "both":
            return image, class_id, mask


In [10]:
transform = SegmentationTransform(size=(224, 224))

dataset = SegmentationDataset(
    json_path="train_semi_annotations_with_seg_ids.json",
    image_root="train-semi",
    mask_root="train-semi-segmentation",
    transform=transform
)

dataloader = DataLoader(dataset, batch_size=8, shuffle=True, num_workers=4)


In [11]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = UNet(n_channels=3, n_classes=50).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-4)


In [12]:
num_epochs = 5

for epoch in range(num_epochs):
    model.train()
    epoch_loss = 0

    for images, masks in dataloader:
        images = images.to(device)
        masks = masks.to(device)

        outputs = model(images)  # shape: [B, 50, H, W]
        loss = criterion(outputs, masks)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        epoch_loss += loss.item()

    print(f"Epoch {epoch+1}/{num_epochs}, Loss: {epoch_loss / len(dataloader):.4f}")


RuntimeError: only batches of spatial targets supported (3D tensors) but got targets of size: : [8]