In [None]:
!pip install torch==2.1.1 torchvision==0.16.1 torchaudio==2.1.1
!pip install torchmetrics

In [None]:
from google.colab import drive
import os
# Mount Google Drive so the project folder below is reachable from this runtime.
drive.mount(mountpoint='/content/drive')

Mounted at /content/drive


In [None]:
# Project root on the mounted Drive. After the chdir, the notebook's relative
# paths (e.g. "./mask_model_stop.pth") resolve against this directory; it
# presumably also lets `from unet_model import UNet` find the local module.
PROJECT_DIR = "/content/drive/MyDrive/DL1008_Final Competition"
os.chdir(PROJECT_DIR)
data_folder = os.path.join(PROJECT_DIR, "Dataset_Student", "train")
print(data_folder)

**Getting Standardization Parameters**

In [None]:
from unet_model import UNet

In [None]:
import sys
import torch
import numpy as np
import torchmetrics
from torch import nn, optim
from torch.utils.data import random_split,Dataset,DataLoader
from torchvision import tv_tensors
from torchvision.transforms import v2
import random
import torch.nn.functional as F
import PIL.Image

In [None]:
class Semantic_Segmentation_Dataset(Dataset):
    """Frame-level segmentation dataset over a directory of video folders.

    Expected layout (as read by ``split_video_mask``)::

        root_dir/video_<k>/image_<i>.png   # one file per frame
        root_dir/video_<k>/mask.npy        # stacked masks, indexed by frame

    A flat index ``idx`` maps to folder ``video_<idx // F>`` and frame
    ``idx % F``. ``F`` is the number of ``image_*.png`` files in a video
    folder and is assumed to be the same for every folder. Folders are assumed
    to be named ``video_0 .. video_<N-1>`` with no gaps.

    :param root_dir: directory containing the ``video_*`` folders
    :param input_transform: optional callable applied once to each image
    :param target_transform: optional callable applied once to each mask
    """

    def __init__(self, root_dir, input_transform=None, target_transform=None):
        self.root_dir = root_dir
        self.input_transform = input_transform
        self.target_transform = target_transform
        # Filesystem-derived sizes and the augmentation pipeline are built
        # lazily on first use and then cached.
        self._num_videos = None
        self._frames_per_video = None
        self._augment = None

    def _scan(self):
        """Return ``(num_videos, frames_per_video)``, scanning the disk once."""
        if getattr(self, "_frames_per_video", None) is None:
            folders = [d for d in os.listdir(self.root_dir)
                       if os.path.isdir(os.path.join(self.root_dir, d))]
            frames = 0
            if folders:
                first = os.path.join(self.root_dir, folders[0])
                # Count frame images explicitly instead of "all files minus
                # mask.npy", so stray files do not inflate the count.
                frames = sum(1 for name in os.listdir(first)
                             if name.startswith("image_") and name.endswith(".png"))
            self._num_videos = len(folders)
            self._frames_per_video = frames
        return self._num_videos, self._frames_per_video

    def __len__(self):
        num_videos, frames = self._scan()
        return num_videos * frames

    def _augmentation(self):
        """Return the joint image/mask augmentation pipeline (built once)."""
        if getattr(self, "_augment", None) is None:
            self._augment = v2.Compose([
                v2.ToImage(),
                v2.RandomCrop(size=(160,240),padding=(40,60),padding_mode='edge'),
                v2.RandomHorizontalFlip(p=.75),
                v2.RandomVerticalFlip(p=.75),
                v2.RandomRotation(degrees=70)
            ])
        return self._augment

    def transform(self, image, mask):
        """Apply the per-sample transforms, then a shared random augmentation.

        ``input_transform``/``target_transform`` are applied here, exactly once
        per sample. Image and mask then go through the v2 pipeline as a pair,
        so they receive the same random crop, flips and rotation.
        """
        if self.input_transform:
            image = self.input_transform(image)
        if self.target_transform:
            mask = self.target_transform(mask)

        new_img, new_msk = self._augmentation()((image, mask))
        return new_img, new_msk

    def __getitem__(self, idx):
        """Load frame ``idx`` with its mask, transformed and augmented.

        Input/target transforms are applied only inside ``transform``.
        Applying them here as well would normalise every image twice.
        Checkpoints trained with double normalisation saw differently scaled
        inputs.
        """
        _, frames = self._scan()
        video, frame = divmod(idx, frames)
        folder_path = os.path.join(self.root_dir, f"video_{video}")

        image, target = split_video_mask(folder_path, index=frame)
        return self.transform(image, target)

**Loading Data**

In [None]:
def split_video_mask(video_folder,index):
    '''Load one frame of a video folder together with its segmentation mask.

    :param video_folder: directory of a single video (``image_<i>.png`` files
        plus a stacked ``mask.npy``)
    :type video_folder: path
    :param index: frame number; selects ``image_<index>.png`` and
        ``mask.npy[index]``
    :type index: int
    :return: ``(image, mask)``, both ``tv_tensors.Image`` in the default float
        dtype. A 2-D mask slice gets a leading channel dimension from
        ``tv_tensors.Image`` (e.g. ``[1, 160, 240]``).
    :raises FileNotFoundError: if the frame image or ``mask.npy`` is missing
    '''
    img_path = os.path.join(video_folder, f"image_{index}.png")
    mask_path = os.path.join(video_folder, "mask.npy")
    # An explicit check instead of ``assert``, which is stripped under ``python -O``.
    for path in (img_path, mask_path):
        if not os.path.isfile(path):
            raise FileNotFoundError(path)

    # Close the PNG handle deterministically once its pixels have been read.
    with PIL.Image.open(img_path) as pil_img:
        img = tv_tensors.Image(pil_img).to(torch.get_default_dtype())

    all_masks = torch.from_numpy(np.load(mask_path)).to(torch.get_default_dtype())
    mask = tv_tensors.Image(all_masks[index])

    return img, mask

In [None]:
# Per-channel normalisation statistics used by `v2.Normalize` in load_data.
#
# These values were estimated offline from a random sample of 60 training
# videos, using per-image channel means and variances averaged over frames
# and videos. They are hard-coded so the dataset does not have to be
# re-scanned on every run.
# NOTE(review): the estimation snippet divided its running sums by the frame
# count inside the per-video loop and averaged per-image variances. These are
# therefore probably not the true pixel mean/std. If the images are raw
# 0-255 floats (as split_video_mask appears to produce), they look far too
# small. Re-derive before changing preprocessing; existing checkpoints were
# trained with these exact values.
sample_mean = torch.tensor([2.2536, 2.2524, 2.2107])
sample_std = torch.tensor([1.8866, 1.6751, 2.2449])

In [None]:
# Echo the normalisation constants in use (mean first, then std).
print(f"{sample_mean} {sample_std}")

tensor([2.2536, 2.2524, 2.2107]) tensor([1.8866, 1.6751, 2.2449])


In [None]:
def load_data(video_folder, split=(0.7, 0.3), seed=10):
    """Build the normalised dataset and split it into train / validation.

    :param video_folder: directory containing the ``video_*`` folders
    :param split: fractions passed to ``random_split`` (default 70/30)
    :param seed: seed for the split generator; fixing it keeps the split
        identical across runs, which matters when resuming from a checkpoint
    :return: ``(train_set, test_set)`` ``Subset`` objects over one dataset
    """
    input_transform = v2.Normalize(sample_mean,sample_std)

    ss_dataset = Semantic_Segmentation_Dataset(video_folder,
                                            input_transform=input_transform,
                                            target_transform=None
                                            )

    # NOTE(review): both subsets share this single dataset instance, whose
    # `transform` always applies the random augmentation, so validation
    # samples are augmented too.
    generator = torch.Generator().manual_seed(seed)

    train_set, test_set = random_split(ss_dataset, list(split), generator=generator)

    return train_set, test_set

In [None]:
def train_model(model, device,criterion,optimizer,train_loader,epoch,hyperparameters):
    """Run one training epoch, then write a resume checkpoint.

    Every 100 batches it prints the epoch progress, the batch loss and the
    batch Jaccard index (49 classes). After the epoch, the mean training loss
    is stored in the module-level ``batch_loss``. A checkpoint (model,
    optimizer, next epoch, loss and the module-level ``best_val_jac``) is then
    saved to ``mask_model_stop.pth``.

    :param hyperparameters: accepted for interface compatibility. Progress is
        now measured with ``len(train_loader)`` instead of a hard-coded
        dataset size.
    """
    global best_val_jac, batch_loss
    model.train()
    print('Training')
    jaccard = torchmetrics.JaccardIndex(task='multiclass',num_classes=49).to(device)
    num_batches = len(train_loader)
    running_loss = 0.0
    seen_batches = 0

    for idx, (data, labels) in enumerate(train_loader):
        data = data.to(device)
        # Masks presumably arrive as (N, 1, H, W); squeeze only dim 1 so a
        # batch of size 1 keeps its batch dimension.
        labels = labels.to(device).squeeze(1).long()

        optimizer.zero_grad()
        logits = model(data)
        loss = criterion(logits, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        seen_batches += 1

        if idx % 100 == 0:
            predicted = torch.argmax(logits,dim=1)
            print(f"Epoch:{epoch}-{idx/num_batches:.2f}, loss: {loss.item():.2f}, jaccard: {100*jaccard(predicted,labels):.2f}%")

    # Mean over all batches actually processed (idx + 1 of them); guard
    # against an empty loader.
    batch_loss = running_loss / max(seen_batches, 1)
    torch.save({'epoch': epoch+1,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': batch_loss,
        'best_val_jac':best_val_jac,
        }, "mask_model_stop.pth")



In [None]:
def validate_model(model, device,criterion,optimizer,val_loader,epoch,hyperparameters):
    """Evaluate on ``val_loader`` and keep the best checkpoint.

    Computes two numbers:

    * the multiclass (49-class) Jaccard index, accumulated over the whole
      validation set with ``update``/``compute``;
    * pixel accuracy over every validation pixel.

    A dataset-level Jaccard is not directly comparable with a mean of
    per-batch scores.

    If the Jaccard beats the module-level ``best_val_jac``, that global is
    updated and model/optimizer state is saved to ``best_mask_model.pth``.
    The saved ``'loss'`` is whatever the module-level ``batch_loss`` holds.

    ``criterion`` and ``hyperparameters`` are unused; they are kept so the
    signature mirrors ``train_model``. The model is left in eval mode.
    """
    print('Validating')
    global best_val_jac
    jaccard = torchmetrics.JaccardIndex(task='multiclass',num_classes=49).to(device)
    num_batches = len(val_loader)
    correct_pixels = 0
    total_pixels = 0

    model.eval()
    with torch.no_grad():
        for idx, (data, labels) in enumerate(val_loader):
            data = data.to(device)
            # Squeeze only the channel dim so a batch of size 1 keeps its
            # batch dimension.
            labels = labels.to(device).squeeze(1).long()

            predicted = torch.argmax(model(data), dim=1)   # (N, H, W)

            jaccard.update(predicted, labels)
            # Vectorised pixel accuracy over every batch.
            correct_pixels += (predicted == labels).sum().item()
            total_pixels += labels.numel()

            if idx % 100 == 0:
                print(f"Step: {100*idx/num_batches:.2f}")

    if total_pixels == 0:
        print(f"Epoch: {epoch}, validation loader is empty; skipping.")
        return

    val_jac = 100 * jaccard.compute()
    pixel_acc = 100 * correct_pixels / total_pixels

    # Save the best model
    if val_jac > best_val_jac:
        best_val_jac = val_jac
        torch.save({'epoch': epoch+1,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss':batch_loss,
        'best_val_jac':best_val_jac,}, "best_mask_model.pth")

    print(f"Epoch: {epoch}, Validation Accuracy (Jaccard): {val_jac:.2f}%, Pixel Accuracy: {pixel_acc:.2f}%")

In [None]:
    # Cell-level training state (`global` is a no-op outside a function, so plain assignments suffice).
best_val_jac = 0   # best validation Jaccard (%) seen so far; validate_model compares against it
batch_loss = 0     # latest training loss value, written into saved checkpoints

In [None]:
def train_practice(hyperparameters, train_subset,val_subset):
    """Train a 49-class UNet, resuming from ``mask_model_stop.pth`` if present.

    :param hyperparameters: dict with ``"lr"``, ``"batch_size"`` and
        ``"max_epochs"``
    :param train_subset: training dataset (shuffled, ``drop_last=True``)
    :param val_subset: validation dataset (``drop_last=True``, so up to
        ``batch_size - 1`` trailing samples are not evaluated)

    On resume, model/optimizer state and the start epoch are restored. The
    saved best Jaccard and loss go into the module-level ``best_val_jac`` /
    ``batch_loss`` globals that ``train_model``/``validate_model`` read. That
    way a resumed run does not overwrite ``best_mask_model.pth`` with a worse
    model.
    """
    global best_val_jac, batch_loss

    device = "cuda:0" if torch.cuda.is_available() else "cpu"

    net = UNet(n_channels=3,n_classes=49,bilinear=False).to(device)
    optimizer = optim.Adam(net.parameters(),lr=hyperparameters["lr"])

    ckpt_path = "./mask_model_stop.pth"
    if os.path.isfile(ckpt_path) and os.path.getsize(ckpt_path) > 0:
        # map_location lets a GPU-saved checkpoint load on a CPU-only runtime.
        checkpoint = torch.load(ckpt_path, map_location=device)
        net.load_state_dict(checkpoint['model_state_dict'])
        # NOTE: this also restores the optimizer's saved learning rate,
        # overriding hyperparameters["lr"].
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        start_epoch = checkpoint['epoch']
        batch_loss = checkpoint['loss']
        best_val_jac = checkpoint['best_val_jac']
        print(f"Resuming training from epoch: {start_epoch}")
    else:
        start_epoch = 0
        # Fresh run: keep any best score the caller already recorded, and
        # make sure the globals exist for train_model/validate_model.
        best_val_jac = globals().get("best_val_jac", 0)
        batch_loss = globals().get("batch_loss", 0)

    print(f"Training on {device}")
    criterion = nn.CrossEntropyLoss().to(device)
    print("++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++")

    batch_size = int(hyperparameters["batch_size"])
    trainloader = DataLoader(
        train_subset, batch_size=batch_size, shuffle=True, drop_last=True,num_workers=2
    )
    valloader = DataLoader(
        val_subset, batch_size=batch_size,num_workers=2,drop_last=True
    )

    for epoch in range(start_epoch,hyperparameters["max_epochs"]):
        train_model(net,device,criterion,optimizer,trainloader,epoch,hyperparameters)
        validate_model(net,device,criterion,optimizer,valloader,epoch,hyperparameters)

In [None]:
# Fixed-seed train/validation split over the whole training directory.
train_subset, val_subset = load_data(data_folder)
# For quick experiments, shrink a split with e.g.
# torch.utils.data.Subset(train_subset, range(12000)).
print(*(len(subset) for subset in (train_subset, val_subset)))

15400 6600


In [None]:
hyperparameters = {"lr":0.001,"batch_size":32,"max_epochs":50}

BEST_CKPT = "./best_mask_model.pth"
# Keep training until the best validation Jaccard (in %) reaches this value.
threshold = 90

model_exists = os.path.isfile(BEST_CKPT) and os.path.getsize(BEST_CKPT) > 0

needs_training = True
if model_exists:
    # Only the recorded score is needed here: train_practice rebuilds the
    # model and optimizer itself (resuming from mask_model_stop.pth).
    checkpoint = torch.load(BEST_CKPT, map_location="cpu")
    best_val_jac = checkpoint['best_val_jac']
    model_below_threshold = best_val_jac < threshold
    print(best_val_jac,model_below_threshold)
    needs_training = bool(model_below_threshold)

if needs_training:
    train_practice(hyperparameters,train_subset,val_subset)

# Model is trained: load the best checkpoint for evaluation/inference.
net = UNet(n_channels=3,n_classes=49,bilinear=False)
checkpoint = torch.load(BEST_CKPT, map_location="cpu")
optimizer = optim.Adam(net.parameters(),lr=hyperparameters["lr"])
net.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
best_val_jac = checkpoint["best_val_jac"]
print(f"Final validation accuracy (Jaccard): {best_val_jac:.2f}%")

tensor(87.1530, device='cuda:0') tensor(True, device='cuda:0')
Resuming training from epoch: 30
Training on cuda:0
++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
Training
Epoch:30-0.00, loss: 0.01, jaccard: 86.22%
Epoch:30-0.21, loss: 0.01, jaccard: 86.28%
Epoch:30-0.42, loss: 0.01, jaccard: 84.68%
Epoch:30-0.62, loss: 0.01, jaccard: 87.05%
Epoch:30-0.83, loss: 0.01, jaccard: 86.83%
Validating
Step: 0.00
Step: 48.54
Step: 97.09
Epoch: 30, Validation Accuracy (Jaccard): 86.24%, Pixel Accuracy: 99.63%
Training
Epoch:31-0.00, loss: 0.01, jaccard: 89.66%
Epoch:31-0.21, loss: 0.01, jaccard: 85.19%
Epoch:31-0.42, loss: 0.01, jaccard: 82.60%
Epoch:31-0.62, loss: 0.01, jaccard: 86.72%
Epoch:31-0.83, loss: 0.01, jaccard: 86.66%
Validating
Step: 0.00
Step: 48.54
Step: 97.09
Epoch: 31, Validation Accuracy (Jaccard): 86.50%, Pixel Accuracy: 99.68%
Training
Epoch:32-0.00, loss: 0.01, jaccard: 89.47%
Epoch:32-0.21, loss: 0.01, jaccard: 85.52%
Epoch:32-0.42, loss: 0.01, ja

KeyboardInterrupt: ignored

In [None]:
# Print the best validation Jaccard recorded in a saved checkpoint.
# map_location lets it load without a GPU; float() also accepts a plain int score.
checkpoint = torch.load("./89valjac.pth", map_location="cpu")
print(float(checkpoint['best_val_jac']))

88.97635650634766


In [None]:
import gc

In [None]:
# Collect unreachable Python objects first, so CUDA tensors they reference
# are freed. Then hand the now-unused cached blocks back to the driver.
# Running empty_cache first would leave that memory cached.
collected = gc.collect()
torch.cuda.empty_cache()
collected

2853