# Training the model on images

Domain adaptation

## Imports

In [1]:
# load custom scripts
from dataset import *
from utils import *
import config 

# import the necessary packages
from albumentations.pytorch import ToTensorV2
from imutils import paths
from skimage import io
from sklearn.model_selection import train_test_split
from torch import nn
from torch import optim
from torch.nn import CrossEntropyLoss 
from torch.optim import Adam
from torch.utils.data import DataLoader, WeightedRandomSampler
from torchvision import transforms, models
from tqdm import tqdm
from utils import EarlyStopping
import albumentations as A
import cv2
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
import random
import segmentation_models_pytorch as smp
import time
import torch
import torchmetrics
import os

## Controlling sources of randomness

In [2]:
# Set seed
SEED = 42
seed_all(SEED)

## Setup CUDA

In [3]:
print(f"Is CUDA supported by this system? {torch.cuda.is_available()}")
print(f"[INFO] CUDA version: {torch.version.cuda}")
  
# storing ID of current CUDA device
cuda_id = torch.cuda.current_device()
print(f"[INFO] ID of current CUDA device:{torch.cuda.current_device()}")
        
print(f"[INFO] Name of current CUDA device:{torch.cuda.get_device_name(cuda_id)}")

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

Is CUDA supported by this system? True
[INFO] CUDA version: 11.1
[INFO] ID of current CUDA device:0
[INFO] Name of current CUDA device:NVIDIA GeForce RTX 3090


## Load images

In [6]:
# load the image and label filepaths in a sorted manner
imagePaths = sorted(list(paths.list_images(config.IMAGE_DATASET_PATH)))
labelPaths = sorted(list(paths.list_images(config.LABEL_DATASET_PATH)))

trainImages_source, testImages, trainLabels_source, testLabels = train_test_split(imagePaths,
                                                                    labelPaths,
                                                                    test_size=0.2,
                                                                    train_size=0.8,
                                                                    random_state=SEED)
testImages, valImages, testLabels, valLabels = train_test_split(testImages, 
                                                                  testLabels,
                                                                  test_size=0.5,
                                                                  train_size=0.5,
                                                                  random_state=SEED)

# get the masks corresponding to the labels
trainMasks = [s.replace('labels', 'masks_9x9') for s in trainLabels_source]
valMasks = [s.replace('labels', 'masks_9x9') for s in valLabels]

# save testing images to disk
print("[INFO] saving testing image paths...")
f = open(config.TEST_PATHS, "w")
f.write("\n".join(testImages))
f.close()

[INFO] saving testing image paths...


In [7]:
imagePaths = sorted(list(paths.list_images("/data/jantina/CoralNet/inference/images/train/")))
labelPaths = sorted(list(paths.list_images("/data/jantina/CoralNet/inference/labels/train/")))

trainImages_target, valImages, trainLabels_target, valLabels = train_test_split(imagePaths,
                                                                  labelPaths,
                                                                  test_size=0.2,
                                                                  train_size=0.8,
                                                                  random_state=SEED)

## Transforms

In [8]:
transform_source = A.Compose([
    A.RandomResizedCrop(width=128, height=128, scale=(0.08, 1.0), ratio=(0.75, 1.33), interpolation=cv2.INTER_NEAREST),    
    A.HorizontalFlip(p=0.5),              
    A.RandomRotate90(p=0.5),
    A.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.1, p=0.5),
    A.GaussianBlur(blur_limit=(3, 7), sigma_limit=0, always_apply=False, p=0.5),
    A.CLAHE(p=0.8),
    A.RandomBrightnessContrast(p=0.8),    
    A.RandomGamma(p=0.8),
    A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
    ToTensorV2()]
)

In [9]:
transform_target = A.Compose([
    A.RandomResizedCrop(width=128, height=128, scale=(0.08, 1.0), ratio=(0.75, 1.33), interpolation=cv2.INTER_NEAREST),    
    A.HorizontalFlip(p=0.5),              
    A.RandomRotate90(p=0.5),
    A.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.1, p=0.5),
    A.GaussianBlur(blur_limit=(3, 7), sigma_limit=0, always_apply=False, p=0.5),
    A.CLAHE(p=0.8),
    A.RandomBrightnessContrast(p=0.8),    
    A.RandomGamma(p=0.8),
    A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
    ToTensorV2()]
)

## Create dataset

In [10]:
# create the train and validation datasets
startTime = time.time()
train_source = Dataset2(imagePaths=trainImages_source, labelPaths = trainLabels_source, transform=transform_source)
train_target = Dataset2(imagePaths=trainImages_target, labelPaths = trainLabels_target, transform=transform_target)


print(f"[INFO] found {len(train_source)} examples in the source set")
print(f"[INFO] found {len(train_target)} examples in the target set")
endTime = time.time()
print("[INFO] total time taken to load the data: {:.2f}s".format(endTime - startTime))

[INFO] found 6829 examples in the source set
[INFO] found 46 examples in the target set
[INFO] total time taken to load the data: 0.00s


In [11]:
half_batch = config.BATCH_SIZE // 2

In [12]:
# create the training and validation data loaders
source_loader = DataLoader(train_source, shuffle=True, 
                         batch_size=config.BATCH_SIZE, 
                         pin_memory=config.PIN_MEMORY, 
                         num_workers=os.cpu_count(),
                         persistent_workers=True,
                         worker_init_fn=seed_worker)

target_loader = DataLoader(train_target, shuffle=False, 
                       batch_size=config.BATCH_SIZE, 
                       pin_memory=config.PIN_MEMORY, 
                       num_workers=os.cpu_count(),
                       persistent_workers=True,
                       worker_init_fn=seed_worker)

## Model initialization

In [13]:
# import the pretrained model
unet = torch.load("/data/jantina/CoralNet/dataset/output/weighted.pth").to(config.DEVICE)

In [38]:
feature_extractor = unet.encoder
clf = unet.decoder

In [39]:
discriminator = nn.Sequential(
    GradientReversal(),
    nn.Linear(320, 50),
    nn.ReLU(),
    nn.Linear(50, 20),
    nn.ReLU(),
    nn.Linear(20, 1)
).to(config.DEVICE)

In [40]:
# initialize loss function
lossFunc = CrossEntropyLoss(ignore_index=0)

# initialize optimizer
opt = Adam(list(discriminator.parameters()) + list(unet.parameters()))

In [None]:
model = Unet()
Discriminator = MLP(512, 256, n_domain_classes)

for (Img, label, domain_label) in Dataloader:

    prediction = model(Img)


    # img of shape: Bx3x128x128
    encoding = model.encoder(Img)
    # encoding of shape: Bx512x15x15

    encoding_global = encoding.mean(dim=2).mean(dim=3)
    #encoding_global of shape: Bx512

    domain_prediction = Discriminator(encoding_global)

    segmentation_loss = crossentropy(predictiomn, label)
    domain_loss = crossentropy(domain_prediction, domain_label)

    loss = added loss with flipped gradient

In [52]:
for epoch in range(1, 5+1):
    batches = zip(source_loader, target_loader)
    n_batches = min(len(source_loader), len(target_loader))

    total_domain_loss = total_label_accuracy = 0
    
    for (source_x, source_labels), (target_x, target_labels) in tqdm(batches, leave=False, total=n_batches):
            x = torch.cat([source_x, target_x])
            x = x.to(config.DEVICE)
            
            domain_y = torch.cat([torch.ones(source_x.shape[0]),
                                  torch.zeros(target_x.shape[0])])
            domain_y = domain_y.to(config.DEVICE)
            label_y = source_labels.to(config.DEVICE)

            features = feature_extractor(x).view(x.shape[0], -1)
            domain_preds = discriminator(features).squeeze()
            label_preds = clf(features[:source_x.shape[0]])

            domain_loss = lossFunc(domain_preds, domain_y)
            label_loss = lossFunc(label_preds, label_y)
            loss = domain_loss + label_loss

            optim.zero_grad()
            loss.backward()
            optim.step()

            total_domain_loss += domain_loss.item()
            total_label_accuracy += (label_preds.max(1)[1] == label_y).float().mean().item()

    mean_loss = total_domain_loss / n_batches
    mean_accuracy = total_label_accuracy / n_batches
    tqdm.write(f'EPOCH {epoch:03d}: domain_loss={mean_loss:.4f}, '
               f'source_accuracy={mean_accuracy:.4f}')

    torch.save(model.state_dict(), 'trained_models/revgrad.pt')

                                                                                                                           

<class 'torch.Tensor'>




AttributeError: 'list' object has no attribute 'view'