In [1]:
import torch
import os
import pandas as pd

In [2]:
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
DEVICE

'cuda'

# Custom Dataset

In [3]:
import os
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import torchvision.transforms.functional as transforms_F
import pandas as pd

In [4]:
class_df = pd.read_csv('/kaggle/input/camvid/CamVid/class_dict.csv')
# Create a dictionary that maps rgb value to 32 CamVid's class indices
RGB2label_dict = {
    (row['r'], row['g'], row['b']): idx
    for idx, row in class_df.iterrows()
}
label2RGB_dict = {
    v: k for k, v in RGB2label_dict.items()
}

In [5]:
class CamVidDataset(Dataset):
    def __init__(self, img_dir: str, label_dir: str, augmentation: bool=False):
        self.img_dir = img_dir
        self.label_dir = label_dir
        self.augmentation = augmentation
        self.img_files = os.listdir(self.img_dir)
        self.label_files = os.listdir(self.label_dir)

        self.transform = transforms.Compose([
            transforms.Resize((384, 480)),
            transforms.ToTensor()
        ])
        
    def __len__(self):
        return len(self.img_files)

    def _augment(self, image: torch.Tensor, label: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        # Horizontal flip with p=0.5
        if torch.randn(1) > 0.5:
            image = transforms_F.hflip(image)
            label = transforms_F.hflip(label)
        # Pad for cropping
        image = transforms_F.pad(image, (10, 10, 10, 10))
        label = transforms_F.pad(label, (10, 10, 10, 10))
        # RandomCrop
        i, j, h, w = transforms.RandomCrop.get_params(image, output_size=(384, 480))
        image = transforms_F.crop(image, i, j, h, w)
        label = transforms_F.crop(label, i, j, h, w)

        image = transforms.ColorJitter(brightness=0.1, contrast=0, saturation=0, hue=0.2)(image)
        return image, label
        
    def __getitem__(self, idx):
        img_file = self.img_files[idx]
        label_file = self.label_files[idx]

        img_path = os.path.join(self.img_dir, img_file)
        label_path = os.path.join(self.label_dir, label_file)

        image = Image.open(img_path)
        label = Image.open(label_path)

        # Transform
        image = self.transform(image).to(DEVICE)
        label = self.transform(label).to(DEVICE)

        # If augmentation is on, apply augmentation
        if self.augmentation:
            image, label = self._augment(image, label)

        # Masking label image pixel by pixel
        label = label.permute(1, 2, 0) # (C, H, W) -> (H, W, C)
        label = (label * 255).int() # Scale back to 0~255 as torch.ToTensor() scaled the image to 0~1
        masked_label = torch.zeros(label.size(0), label.size(1), dtype=torch.uint8, device=DEVICE)
        for rgb, idx in RGB2label_dict.items(): # Mask the pixels for every class type
            rgb_tensor = torch.tensor(rgb, device=DEVICE)
            masked_label[(label == rgb_tensor).all(axis=-1)] = idx

        return image, masked_label.long()
        

In [6]:
train_img_dir = '/kaggle/input/camvid/CamVid/train'
train_label_dir ='/kaggle/input/camvid/CamVid/train_labels'
train_dataset = CamVidDataset(train_img_dir, train_label_dir, False)

val_img_dir = '/kaggle/input/camvid/CamVid/val'
val_label_dir = '/kaggle/input/camvid/CamVid/val_labels'
val_dataset = CamVidDataset(val_img_dir, val_label_dir, False)

In [7]:
train_dataset[0][0].shape

torch.Size([3, 384, 480])

In [8]:
BATCH_SIZE = 8
train_dataloader = DataLoader(train_dataset,
                              batch_size=BATCH_SIZE,
                              shuffle=True)
val_dataloader = DataLoader(val_dataset,
                            batch_size=BATCH_SIZE,
                            shuffle=False)

# Model

In [9]:
import torch
from torch import nn
import torchvision.models as models

class FCN_8s(nn.Module):
    def __init__(self, num_classes):
        super().__init__()

        vgg16 = models.vgg16_bn(weights="IMAGENET1K_V1", progress=True)

        # Use the features from vgg16
        self.features = vgg16.features

        # Replace the classifier with convolutional layers
        self.score_pool4 = nn.Sequential(
            nn.Conv2d(512, num_classes, kernel_size=1),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5)
        )
        self.score_pool3 = nn.Sequential(
            nn.Conv2d(256, num_classes, kernel_size=1),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5)
        )

        self.score_fr = nn.Sequential(
            nn.Conv2d(512, 4096, kernel_size=1),
            nn.ReLU(inplace=True),
            nn.Dropout2d(),
            nn.Conv2d(4096, 4096, kernel_size=1),
            nn.ReLU(inplace=True),
            nn.Dropout2d(),
            nn.Conv2d(4096, num_classes, kernel_size=1)
        )

        # Transposed convolution layers for upsampling
        '''
        score_fr*2 means score_fr upsampled by factor of 2 using Transposed Convolution
        '''
        # self.upscore_pool5 = nn.ConvTranspose2d(
        #     num_classes, num_classes, kernel_size=4, stride=2, padding=1
        # ) # Upsamples the score_fr by factor of 2
        # self.upscore_pool4 = nn.ConvTranspose2d(
        #     num_classes, num_classes, kernel_size=4, stride=2, padding=1
        # ) # Upsamples the (score_fr*2 + score_pool4) by factor of 2
        # self.upscore_pool3 = nn.ConvTranspose2d(
        #     num_classes, num_classes, kernel_size=16, stride=8, padding=4
        # ) # Upsamples the [(score_fr*2 + score_pool4)*2 + score_pool3] by factor of 8
        self.upscore_pool5 = nn.ConvTranspose2d(
            num_classes, num_classes, kernel_size=2, stride=2, bias=False
        ) # Upsamples the score_fr by factor of 2
        self.upscore_pool4 = nn.ConvTranspose2d(
            num_classes, num_classes, kernel_size=2, stride=2, bias=False
        ) # Upsamples the (score_fr*2 + score_pool4) by factor of 2
        self.upscore_pool3 = nn.ConvTranspose2d(
            num_classes, num_classes, kernel_size=8, stride=8, bias=False
        ) # Upsamples the [(score_fr*2 + score_pool4)*2 + score_pool3] by factor of 8

    def forward(self, x):
        # Store intermediate outputs for skip connections
        pool3 = None
        pool4 = None

        '''input x: [N, 3, 384, 480] '''
        # Forward pass through VGG16 features
        for i in range(len(self.features)):
            x = self.features[i](x) # Feed forwarding the previous output to each layer coming next
            if i == 23: # After pool_3 layer passed
                pool3 = x 
                ''' pool3: [N, 256, 48, 60] '''
            elif i == 33: # After pool_4 layer passed
                pool4 = x 
                ''' pool4: [N, 512, 24, 30] '''
            elif i == 43: 
                ''' x: [N, 512, 12, 15] '''
                break;

        # Classify the features
        # x is now the output from the last pooling layer(pool_5) of vgg16_bn
        ''' x: [N, 32, 12, 15] '''
        x = self.score_fr(x) # (N, num_classes, H/32, W/32)

        # Upsample the pool5 score by factor of 2
        ''' x: [N, 32, 24, 30] '''
        x = self.upscore_pool5(x) # (N, num_classes, H/16, W/16)
        # Add skip connection from pool4
        ''' score_pool4: [N, 32, 24, 30] '''
        score_pool4 = self.score_pool4(pool4)
        x = x + score_pool4

        # Upsample the skip-connected pool4+pool5 score by factor of 2
        ''' x: [N, 32, 48, 60] '''
        x = self.upscore_pool4(x) # (N, num_classes, H/8, W/8)
        # Add skip connection from pool3
        ''' score_pool3: [N, 32, 48, 60] '''
        score_pool3 = self.score_pool3(pool3)
        x = x + score_pool3

        # Finally, upsample the skip-connected pool3+pool4+pool5 score by factor of 8
        x = self.upscore_pool3(x) # (N, num_classes, H, W)
        ''' x: [N, 32, 384, 480] '''
        # The output tensor now has the same spatial dimensions as the input
        return x
        

In [10]:
fcn_8s = FCN_8s(num_classes=32).to(DEVICE)

Downloading: "https://download.pytorch.org/models/vgg16_bn-6c64b313.pth" to /root/.cache/torch/hub/checkpoints/vgg16_bn-6c64b313.pth
100%|██████████| 528M/528M [00:02<00:00, 241MB/s]


# Train

In [11]:
loss_acc_dict = {
    'train_loss_lst' : [],
    'train_acc_lst' : [],
    'val_loss_lst' : [],
    'val_acc_lst' : []
}

In [12]:
import random
import matplotlib.pyplot as plt

def label_to_rgb_tensor(label_tensor: torch.Tensor) -> torch.Tensor:
    height, width = label_tensor.shape
    rgb_image = torch.zeros(3, height, width, dtype=torch.uint8)

    for label, rgb in label2RGB_dict.items():
        mask = (label_tensor == label)
        rgb_image[0][mask] = rgb[0]  # Red
        rgb_image[1][mask] = rgb[1]  # Green
        rgb_image[2][mask] = rgb[2]  # Blue

    return rgb_image

def visualize_segmentation(model, val_loader, device, epoch):
    model.eval()
    batch_idx = random.randint(0, len(val_loader) - 1)
    images, labels = list(val_loader)[batch_idx]
    images = images.to(device)
    labels = labels.to(device)

    with torch.no_grad():
        outputs = model(images)
        preds = torch.argmax(outputs, dim=1)

    img_idx = random.randint(0, len(images) - 1)
    img = images[img_idx].cpu().numpy().transpose(1, 2, 0)
    label = labels[img_idx].cpu().numpy()
    pred = preds[img_idx].cpu().numpy()
    pred_rgb = label_to_rgb_tensor(pred).cpu().numpy().transpose(1, 2, 0)

    fig, axes = plt.subplots(1, 3, figsize=(15, 5))
    axes[0].imshow(img)
    axes[0].set_title('Original Image')
    
    axes[1].imshow(label_to_rgb_tensor(label).permute(1, 2, 0).cpu().numpy())
    axes[1].set_title('Ground Truth')

    axes[2].imshow(pred_rgb)
    axes[2].set_title('Predicted Mask')

    # 이미지 파일 저장
    output_path = f'/kaggle/working/segmentation_epoch_{epoch}.png'
    plt.savefig(output_path)
    plt.close()

    print(f"Segmentation visualization saved at {output_path}")

In [13]:
from tqdm import tqdm

def train(model, dataloader, optimizer, loss_fn):
    model.train()
    train_loss, train_acc = 0.0, 0.0
    total_pixels, correct_pixels = 0, 0

    for images, label_images in tqdm(dataloader):
        images = images.to(DEVICE)
        label_images = label_images.to(DEVICE)

        optimizer.zero_grad()
        y_logits = model(images)
        loss = loss_fn(y_logits, label_images)
        loss.backward()
        optimizer.step()

        train_loss += loss.item()
        # y_preds: (N, H, W)
        y_preds = torch.argmax(y_logits, axis=1) # argmax along channels (N, C, H, W)
        correct_pixels += (label_images == y_preds).sum().item()
        total_pixels += label_images.numel()

    # Calculate average loss and accuracy for the batch
    train_loss /= len(dataloader)
    train_acc = 100 * (correct_pixels / total_pixels)

    return train_loss, train_acc

In [14]:
def evaluate(model, dataloader, optimizer, loss_fn):
    model.eval()
    val_loss, val_acc = 0.0, 0.0
    correct_pixels, total_pixels = 0, 0

    for images, label_images, in dataloader:
        images = images.to(DEVICE)
        label_images = label_images.to(DEVICE)

        y_logits = model(images)
        loss = loss_fn(y_logits, label_images)

        val_loss += loss.item()
        y_preds = torch.argmax(y_logits, axis=1)
        correct_pixels += (label_images == y_preds).sum().item()
        total_pixels += label_images.numel()

    # Average loss/acc over the batches
    val_loss /= len(dataloader)
    val_acc = 100 * (correct_pixels / total_pixels)
    return val_loss, val_acc

In [15]:
import time
import copy
from torch.optim import lr_scheduler

def train_model(model, 
                train_dataloader, 
                val_dataloader,
                optimizer,
                loss_fn,
                num_epochs=1):
    best_acc = 0.0
    best_model_wts = copy.deepcopy(model.state_dict())

    for epoch in range(1, num_epochs + 1):
        start = time.time()
        # Feed forward / backprop on train_dataloader
        train_loss, train_acc = train(model, train_dataloader, optimizer, loss_fn)
        # Feed forward on val_dataloader
        val_loss, val_acc = evaluate(model, val_dataloader, optimizer, loss_fn)

        # Storing epoch histories
        loss_acc_dict['train_loss_lst'].append(train_loss)
        loss_acc_dict['train_acc_lst'].append(train_acc)
        loss_acc_dict['val_loss_lst'].append(val_loss)
        loss_acc_dict['val_acc_lst'].append(val_acc)

        # Update model depending on its peformance on validation data
        if val_acc > best_acc:
            best_acc = val_acc
            best_model_wts = copy.deepcopy(model.state_dict())

        # Scheduler Update
        # scheduler.step()
        
        end = time.time()
        time_elapsed = end - start
        print(f"------------ epoch {epoch} ------------")
        print(f"Train loss: {train_loss:.4f} | Train acc: {train_acc:.2f}%")
        print(f"Validation loss: {val_loss:.4f} | Validation acc: {val_acc:.2f}%")
        print(f"Time taken: {time_elapsed / 60:.0f}min {time_elapsed % 60:.0f}s")

        if (epoch - 1) % 10 == 0:
            visualize_segmentation(model, val_dataloader, DEVICE, epoch - 1)

    
    model.load_state_dict(best_model_wts)
    return model

In [16]:
# loss_fn = nn.CrossEntropyLoss()
# optimizer = torch.optim.SGD([
#     {'params': fcn_8s.features.parameters(), 'lr': 1e-4},
#     {'params': fcn_8s.score_pool3.parameters(), 'lr': 1e-3},
#     {'params': fcn_8s.score_pool4.parameters(), 'lr': 1e-3},
#     {'params': fcn_8s.score_fr.parameters(), 'lr': 1e-3},
#     {'params': fcn_8s.upscore_pool3.parameters(), 'lr': 1e-3},
#     {'params': fcn_8s.upscore_pool4.parameters(), 'lr': 1e-3},
#     {'params': fcn_8s.upscore_pool5.parameters(), 'lr': 1e-3}
# ], momentum=0.9, weight_decay=0.0005)
# scheduler = lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.1)

In [17]:
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(fcn_8s.parameters(), lr=1e-3)
#scheduler = lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.1)

In [18]:
fcn8s_trained = train_model(fcn_8s,
                            train_dataloader,
                            val_dataloader,
                            optimizer,
                            loss_fn,
                            num_epochs=30)

100%|██████████| 47/47 [00:38<00:00,  1.23it/s]


------------ epoch 1 ------------
Train loss: 2.3386 | Train acc: 24.31%
Validation loss: 2.0704 | Validation acc: 33.71%
Time taken: 1min 46s
Segmentation visualization saved at /kaggle/working/segmentation_epoch_0.png


100%|██████████| 47/47 [00:31<00:00,  1.51it/s]


------------ epoch 2 ------------
Train loss: 1.6890 | Train acc: 45.73%
Validation loss: 1.6353 | Validation acc: 51.15%
Time taken: 1min 37s


100%|██████████| 47/47 [00:31<00:00,  1.50it/s]


------------ epoch 3 ------------
Train loss: 1.5612 | Train acc: 51.70%
Validation loss: 1.5063 | Validation acc: 54.05%
Time taken: 1min 37s


100%|██████████| 47/47 [00:31<00:00,  1.51it/s]


------------ epoch 4 ------------
Train loss: 1.5019 | Train acc: 53.26%
Validation loss: 1.5056 | Validation acc: 53.36%
Time taken: 1min 37s


100%|██████████| 47/47 [00:31<00:00,  1.51it/s]


------------ epoch 5 ------------
Train loss: 1.4780 | Train acc: 53.53%
Validation loss: 1.4971 | Validation acc: 53.75%
Time taken: 1min 37s


100%|██████████| 47/47 [00:30<00:00,  1.52it/s]


------------ epoch 6 ------------
Train loss: 1.4481 | Train acc: 54.28%
Validation loss: 1.4495 | Validation acc: 54.91%
Time taken: 1min 37s


100%|██████████| 47/47 [00:30<00:00,  1.52it/s]


------------ epoch 7 ------------
Train loss: 1.4209 | Train acc: 54.94%
Validation loss: 1.5282 | Validation acc: 53.73%
Time taken: 1min 37s


100%|██████████| 47/47 [00:31<00:00,  1.51it/s]


------------ epoch 8 ------------
Train loss: 1.4070 | Train acc: 55.17%
Validation loss: 1.4164 | Validation acc: 55.53%
Time taken: 1min 37s


100%|██████████| 47/47 [00:31<00:00,  1.51it/s]


------------ epoch 9 ------------
Train loss: 1.4114 | Train acc: 55.14%
Validation loss: 1.3953 | Validation acc: 55.99%
Time taken: 1min 37s


100%|██████████| 47/47 [00:31<00:00,  1.51it/s]


------------ epoch 10 ------------
Train loss: 1.3959 | Train acc: 55.77%
Validation loss: 1.4962 | Validation acc: 52.23%
Time taken: 1min 37s


100%|██████████| 47/47 [00:31<00:00,  1.51it/s]


------------ epoch 11 ------------
Train loss: 1.3972 | Train acc: 55.65%
Validation loss: 1.4423 | Validation acc: 55.11%
Time taken: 1min 37s
Segmentation visualization saved at /kaggle/working/segmentation_epoch_10.png


100%|██████████| 47/47 [00:30<00:00,  1.52it/s]


------------ epoch 12 ------------
Train loss: 1.3625 | Train acc: 56.57%
Validation loss: 1.4340 | Validation acc: 54.70%
Time taken: 1min 37s


100%|██████████| 47/47 [00:30<00:00,  1.53it/s]


------------ epoch 13 ------------
Train loss: 1.3590 | Train acc: 56.77%
Validation loss: 1.5217 | Validation acc: 51.68%
Time taken: 1min 37s


100%|██████████| 47/47 [00:30<00:00,  1.52it/s]


------------ epoch 14 ------------
Train loss: 1.3518 | Train acc: 57.16%
Validation loss: 1.4931 | Validation acc: 54.02%
Time taken: 1min 37s


100%|██████████| 47/47 [00:30<00:00,  1.53it/s]


------------ epoch 15 ------------
Train loss: 1.3389 | Train acc: 57.49%
Validation loss: 1.4391 | Validation acc: 55.07%
Time taken: 1min 37s


100%|██████████| 47/47 [00:30<00:00,  1.52it/s]


------------ epoch 16 ------------
Train loss: 1.3352 | Train acc: 57.56%
Validation loss: 1.5830 | Validation acc: 54.07%
Time taken: 1min 37s


100%|██████████| 47/47 [00:30<00:00,  1.52it/s]


------------ epoch 17 ------------
Train loss: 1.3145 | Train acc: 58.35%
Validation loss: 1.4239 | Validation acc: 55.40%
Time taken: 1min 37s


100%|██████████| 47/47 [00:30<00:00,  1.53it/s]


------------ epoch 18 ------------
Train loss: 1.3042 | Train acc: 58.72%
Validation loss: 1.6639 | Validation acc: 50.06%
Time taken: 1min 37s


100%|██████████| 47/47 [00:30<00:00,  1.52it/s]


------------ epoch 19 ------------
Train loss: 1.2802 | Train acc: 59.17%
Validation loss: 1.6420 | Validation acc: 50.19%
Time taken: 1min 37s


100%|██████████| 47/47 [00:30<00:00,  1.52it/s]


------------ epoch 20 ------------
Train loss: 1.2751 | Train acc: 59.96%
Validation loss: 1.4634 | Validation acc: 55.06%
Time taken: 1min 37s


100%|██████████| 47/47 [00:30<00:00,  1.52it/s]


------------ epoch 21 ------------
Train loss: 1.2434 | Train acc: 60.65%
Validation loss: 1.6247 | Validation acc: 53.09%
Time taken: 1min 37s
Segmentation visualization saved at /kaggle/working/segmentation_epoch_20.png


100%|██████████| 47/47 [00:31<00:00,  1.52it/s]


------------ epoch 22 ------------
Train loss: 1.2105 | Train acc: 61.82%
Validation loss: 1.5007 | Validation acc: 53.88%
Time taken: 1min 37s


100%|██████████| 47/47 [00:30<00:00,  1.52it/s]


------------ epoch 23 ------------
Train loss: 1.1854 | Train acc: 62.93%
Validation loss: 1.5258 | Validation acc: 54.60%
Time taken: 1min 37s


100%|██████████| 47/47 [00:31<00:00,  1.50it/s]


------------ epoch 24 ------------
Train loss: 1.1516 | Train acc: 64.35%
Validation loss: 1.5709 | Validation acc: 52.14%
Time taken: 1min 37s


100%|██████████| 47/47 [00:30<00:00,  1.52it/s]


------------ epoch 25 ------------
Train loss: 1.1485 | Train acc: 64.43%
Validation loss: 1.5587 | Validation acc: 53.80%
Time taken: 1min 37s


100%|██████████| 47/47 [00:31<00:00,  1.51it/s]


------------ epoch 26 ------------
Train loss: 1.1118 | Train acc: 66.00%
Validation loss: 1.6554 | Validation acc: 50.32%
Time taken: 1min 37s


100%|██████████| 47/47 [00:31<00:00,  1.51it/s]


------------ epoch 27 ------------
Train loss: 1.0982 | Train acc: 66.67%
Validation loss: 1.6111 | Validation acc: 50.27%
Time taken: 1min 37s


100%|██████████| 47/47 [00:31<00:00,  1.51it/s]


------------ epoch 28 ------------
Train loss: 1.0688 | Train acc: 67.92%
Validation loss: 1.5845 | Validation acc: 53.21%
Time taken: 1min 37s


100%|██████████| 47/47 [00:30<00:00,  1.52it/s]


------------ epoch 29 ------------
Train loss: 1.0369 | Train acc: 68.76%
Validation loss: 1.6235 | Validation acc: 49.17%
Time taken: 1min 37s


100%|██████████| 47/47 [00:31<00:00,  1.50it/s]


------------ epoch 30 ------------
Train loss: 1.0084 | Train acc: 69.97%
Validation loss: 1.6175 | Validation acc: 52.33%
Time taken: 1min 37s


# Saving Results

In [19]:
import pickle as pkl

with open('/kaggle/working/loss_acc_dict.pkl', 'wb') as f:
    pkl.dump(loss_acc_dict, f)

torch.save(fcn8s_trained.state_dict(), '/kaggle/working/fcn8s.pth')