In [19]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import Adam
import model_UNet
from data_augmentation import augmentation, colorjiter, invert
import matplotlib.pyplot as plt
from preprocessing_1_dataloader import get_data

In [45]:
#Weigth coef for the Unsupervised
def dice_loss(logits, targets): 
  preds_animal = F.softmax(logits, dim=1)
  targets_animal = torch.squeeze(targets)
  preds_animal = preds_animal[:,1,:,:]
  eps = 1e-6
  intersection = (preds_animal * targets_animal).sum()
  dice_coef = (2. * intersection + eps) / ((preds_animal**2).sum() + eps)
  dice_loss = 1 - dice_coef
  return dice_loss

@torch.no_grad()
def wt(rampup_length, current, alpha, wait_period = 5):

  if current < wait_period:
    return 0.0
    
  else:
    if rampup_length == 0:
                return 1.0
    else:
        current = np.clip(current, 0.0, rampup_length)
        phase = 1.0 - current / rampup_length
        return float(alpha * np.exp(-5.0 * phase * phase))


#update the Teacher weigth
@torch.no_grad()
def update_ema_variables(model, ema_model, alpha, global_step): 
    # Use the true average until the exponential average is more correct
    alpha = min(1 - 1 / (global_step + 1), alpha)
    for ema_param, param in zip(ema_model.parameters(), model.parameters()):
        ema_param.data.mul_(alpha).add_(1 - alpha, param.data)

@torch.no_grad()
def unsup_loss(z, modelT, imgs, unsup_crit): 

  imgT_aug = augmentation(imgs).type(torch.float32)
  imgT_aug = imgT_aug.to(device)

  # z_bar will act as pseudo-labels
  z_bar = modelT(imgT_aug)
  z_bar = F.softmax(z_bar, dim = 1)
  z_bar_preds = torch.argmax(z_bar, dim=1)

  # Transform z_bar into predictions
  Lu = unsup_crit(z, z_bar_preds).to(device)

  return Lu

@torch.no_grad()
def evaluate_model(model, dataloader, device):
  
  model.eval()
  intersection_total, union_total = 0, 0
  pixel_correct, pixel_count = 0, 0
    
  for data in dataloader:
    imgs, labels = data
    imgs, labels = imgs.to(device), labels.to(device)
    logits = model(imgs)
    preds = torch.argmax(logits, dim=1)
    targets = torch.squeeze(labels)
            
    intersection_total += torch.logical_and(preds, targets).sum()
    union_total += torch.logical_or(preds, targets).sum()
            
    pixel_correct += (preds == targets).sum()
    pixel_count += targets.numel()

  iou = (intersection_total / union_total).item()
  accuracy = (pixel_correct / pixel_count).item()
  
  model.train()
  return accuracy, iou

In [55]:
#### Hyper-Param ####

# device to use
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu") #Enable GPU support
print(f'Using device: {device}')

# data parms
supervised_percent = 0.2  # what percent of training is to be labelled
img_resize = 64             # resize all images to this size 
is_only_labelled = False     # Training only on supervised data or not

# model params
depth = 3       # depth of unet

# Training params
batch_size = 32
epochs = 100
ramp_up = 5
consistency = 56
alpha = 0.999
global_step = 0

Using device: mps


In [50]:
#### Initialisation ####
#create 2 network
modelS = model_UNet.UNet(in_channels=3, num_classes=2, depth=depth)
modelS = modelS.to(device)
modelT = model_UNet.UNet(in_channels=3, num_classes=2, depth=depth)
modelT = modelT.to(device)
#create the losses
sup_crit = nn.CrossEntropyLoss().to(device)
unsup_crit = nn.CrossEntropyLoss().to(device)
#optimizer
optimizer = Adam(modelS.parameters())

##data loader
mixed_train_loader, val_loader, test_loader = get_data(supervised_percent,1-supervised_percent,0.2,0.1, batch_size=batch_size, img_resize=img_resize, is_mixed_loader=is_only_labelled)

  init.xavier_normal(m.weight)
  init.constant(m.bias, 0)


In [56]:
# Train
eval_freq = 1
losses, accs, IOUs = [], [], []
optimizer = Adam(modelS.parameters())


for epoch in range(epochs):

        modelS.train()
        running_loss = 0

        w_t = wt(ramp_up, epoch, consistency)

        for step, data in enumerate(mixed_train_loader):

            imgs, labs = data
            # Augment images
            imgS_aug = augmentation(imgs)

            imgS_aug = imgS_aug.to(device)
            labs = labs.squeeze().type(torch.LongTensor).to(device)

            optimizer.zero_grad()

            # Forward pass for student and teacher
            z = modelS(imgS_aug) 

            # Find img with label
            sup_idx = torch.tensor([(elem != -1).item() for elem in labs[:, 0, 0]]).to(device) #If batchsize is the first dim

            assert len(sup_idx) != 0

            # Calculate losses
            Ls = sup_crit(z[sup_idx], labs[sup_idx])
            Lu = unsup_loss(z, modelT, imgs, unsup_crit)
  
            loss = Ls + w_t * Lu
            
            loss.backward()
            
            optimizer.step()    
            global_step += 1
            update_ema_variables(modelS, modelT, alpha, global_step)
            running_loss += loss.item()

            #optimizer.param_groups[0]['lr'] = lr(epoch+2)
            #print(loss.item())

        print(f'Epoch: {epoch + 1:4d} - Loss: {running_loss:6.2f}')
        losses.append(running_loss)

        if (epoch % eval_freq == 0):
          accuracy, IOU = evaluate_model(modelS, val_loader, device)
          accs.append(accuracy)
          IOUs.append(IOU)
          print(f'accuracy: {accuracy:2.0%}; IOU: {IOU:2.0%}')

np.savetxt("losses", losses)
np.savetxt("accs", accs)
np.savetxt("IOUs", IOUs)

Epoch:    1 - Loss:  69.97
accuracy: 82%; IOU: 63%
Epoch:    2 - Loss:  67.06
accuracy: 82%; IOU: 67%
Epoch:    3 - Loss:  65.43
accuracy: 84%; IOU: 69%
Epoch:    4 - Loss:  62.54
accuracy: 82%; IOU: 61%
Epoch:    5 - Loss:  61.81
accuracy: 85%; IOU: 71%
Epoch:    6 - Loss: 2520.15
accuracy: 85%; IOU: 69%
Epoch:    7 - Loss: 2533.57
accuracy: 86%; IOU: 72%
Epoch:    8 - Loss: 2432.82
accuracy: 87%; IOU: 73%
Epoch:    9 - Loss: 2453.86
accuracy: 87%; IOU: 72%
Epoch:   10 - Loss: 2382.75
accuracy: 87%; IOU: 73%
Epoch:   11 - Loss: 2394.56
accuracy: 87%; IOU: 72%
Epoch:   12 - Loss: 2365.74
accuracy: 87%; IOU: 73%
Epoch:   13 - Loss: 2363.22
accuracy: 87%; IOU: 71%
Epoch:   14 - Loss: 2290.71
accuracy: 87%; IOU: 75%
Epoch:   15 - Loss: 2265.78
accuracy: 87%; IOU: 73%
Epoch:   16 - Loss: 2241.05
accuracy: 87%; IOU: 71%
Epoch:   17 - Loss: 2218.72
accuracy: 87%; IOU: 74%
Epoch:   18 - Loss: 2261.46
accuracy: 88%; IOU: 75%
Epoch:   19 - Loss: 2197.85
accuracy: 88%; IOU: 75%
Epoch:   20 - Los

KeyboardInterrupt: 