In [1]:
import os
import glob
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import random

from einops.layers.torch import Rearrange

from timm.data.transforms_factory import create_transform

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset, random_split

import torchvision
from torchvision.datasets import ImageFolder
from torchvision import transforms

from PIL import Image

from eyes.datasets import DannDataset, EyesDataset
from eyes.models.dann import Dann, get_lambda

%matplotlib inline

In [2]:
# --- Run configuration: compute device, RNG seeding, dataset locations ---

# Prefer the second GPU (the original choice), but only when it really exists.
# `torch.cuda.is_available()` is also true on single-GPU machines, where
# 'cuda:1' would fail later with an invalid device ordinal.
if torch.cuda.device_count() > 1:
    device = 'cuda:1'
elif torch.cuda.is_available():
    device = 'cuda:0'
else:
    device = 'cpu'

# Seed every RNG imported by this notebook, not only torch's.
SEED = 0
random.seed(SEED)
np.random.seed(SEED)
# torch.manual_seed seeds the default generator and returns it.
rand_gen = torch.manual_seed(SEED)

# NOTE(review): absolute, machine-specific paths; consider a configurable root.
labeled_datafolder = "/home/dima/datasets/eyes/labeled/"
datafolder = "/home/dima/datasets/eyes/train/"
mrl_datafolder = "/home/dima/datasets/eyes/mrl/"

In [3]:
def _grayscale_pipeline(**timm_kwargs):
    """Compose a single-channel conversion with a timm transform.

    The timm transform is created for input size 24 with mean/std of 0.5;
    any extra keyword arguments are forwarded to `create_transform`.
    """
    timm_transform = create_transform(24, mean=(0.5,), std=(0.5,), **timm_kwargs)
    return transforms.Compose([transforms.Grayscale(), timm_transform])


# Training pipeline: timm training mode with the 'rand-m7-n3' auto-augment policy.
train_transform = _grayscale_pipeline(is_training=True, auto_augment='rand-m7-n3')
# Evaluation pipeline: timm non-training mode, no auto-augment policy.
test_transform = _grayscale_pipeline(is_training=False)



In [4]:
batch_size = 512

# Labeled source domain (MRL), served with training-time augmentation.
mrl_dataset = ImageFolder(mrl_datafolder, transform=train_transform)

# 90/10 split of the source domain.
train_size = int(len(mrl_dataset) * 0.9)
train_mrl_dataset, _heldout_split = random_split(
    mrl_dataset, [train_size, len(mrl_dataset) - train_size]
)

# The subsets returned by random_split share the parent dataset's transform,
# so using the held-out split directly would evaluate on randomly augmented
# images. Re-wrap the same held-out indices over a second ImageFolder view of
# the same directory that applies the deterministic test transform instead.
test_mrl_dataset = torch.utils.data.Subset(
    ImageFolder(mrl_datafolder, transform=test_transform),
    _heldout_split.indices,
)

# Second dataset fed to DannDataset alongside the MRL training split.
# NOTE(review): it is loaded with test_transform (no augmentation) even though
# it is used for training - presumably intentional; confirm.
dataset = EyesDataset(datafolder, transform=test_transform)
train_dataset = DannDataset(train_mrl_dataset, dataset)

trainloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
testloader = DataLoader(test_mrl_dataset, batch_size=batch_size, shuffle=False)

In [5]:
def _conv_stage(in_channels, out_channels):
    """Return the layers of one stage: 3x3 conv (stride 1) -> 2x2 max-pool -> ReLU."""
    return [
        nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1),
        nn.MaxPool2d(2),
        nn.ReLU(),
    ]


def _two_way_head():
    """Return a 2-output MLP head whose first layer infers its input width lazily."""
    return nn.Sequential(
        nn.LazyLinear(2048),
        nn.ReLU(),
        nn.Linear(2048, 2),
    )


# Feature extractor: three conv stages (1 -> 64 -> 256 -> 1024 channels),
# then every non-batch dimension is flattened into a single feature vector.
backbone = nn.Sequential(
    *_conv_stage(1, 64),
    *_conv_stage(64, 256),
    *_conv_stage(256, 1024),
    Rearrange("b c h w -> b (c h w)"),
)

# Two structurally identical heads with separate parameters.
cls_head = _two_way_head()
domain_head = _two_way_head()

model = Dann(backbone, cls_head, domain_head).to(device)



In [6]:
# The same cross-entropy criterion is used for both the class and domain outputs.
criterion = nn.CrossEntropyLoss()

# One AdamW optimizer per sub-network (backbone, class head, domain head),
# all with the same learning rate, so each part can be stepped independently.
optimizer_B, optimizer_C, optimizer_D = (
    optim.AdamW(module.parameters(), lr=5e-4)
    for module in (backbone, cls_head, domain_head)
)

In [7]:
# Adversarial domain-adaptation training loop.
#
# Each iteration performs two updates on one shared feature computation:
#   1. the domain head is trained to tell source (label 0) from target
#      (label 1) on detached features, so the backbone gets no gradient;
#   2. backbone and class head minimise `cls_loss - lambda_coef * domain_loss`,
#      i.e. the backbone is pushed to *increase* the domain loss.
max_epochs = 4
log_every = 10  # number of batches averaged into each progress line
for epoch in range(max_epochs):
    running_cls_loss = running_domain_loss = running_acc = 0.
    for i, data in enumerate(trainloader):
        (source, labels), target = data
        source, labels, target = source.to(device), labels.to(device), target.to(device)
        # Size each half from its own batch rather than assuming both equal len(labels).
        n_source, n_target = source.size(0), target.size(0)

        # Single backbone pass over source and target stacked along the batch axis.
        features = model.get_features(torch.cat([source, target], dim=0))

        domain_labels = torch.cat([
            torch.zeros(n_source, device=device, dtype=torch.long),
            torch.ones(n_target, device=device, dtype=torch.long),
        ])

        # --- Update 1: domain head only (features detached from the backbone) ---
        domain_output = model.predict_domain(features.detach())
        domain_loss = criterion(domain_output, domain_labels)

        optimizer_D.zero_grad()
        domain_loss.backward()
        optimizer_D.step()

        # --- Update 2: backbone + class head against the updated domain head ---
        cls_output = model.predict_class(features[:n_source])
        domain_output = model.predict_domain(features)
        cls_loss = criterion(cls_output, labels)
        domain_loss = criterion(domain_output, domain_labels)
        # NOTE(review): get_lambda only receives the integer epoch here, so the
        # coefficient can change only at epoch boundaries - confirm that this
        # coarse schedule is intended.
        lambda_coef = 0.1 * get_lambda(epoch, max_epochs)
        loss = cls_loss - lambda_coef * domain_loss

        optimizer_B.zero_grad()
        optimizer_C.zero_grad()

        # This backward also leaves gradients on the domain head; they are never
        # applied because optimizer_D.zero_grad() runs before its next backward.
        loss.backward()

        optimizer_B.step()
        optimizer_C.step()

        # Batch accuracy on the source samples, computed on-device without
        # the deprecated `.data` attribute or extra host copies.
        predicted = cls_output.detach().argmax(dim=1)
        total = labels.size(0)
        correct = (predicted == labels).sum().item()
        acc = correct / total

        running_cls_loss += cls_loss.item()
        running_domain_loss += domain_loss.item()
        running_acc += acc
        if i % log_every == log_every - 1:
            print('[%d, %5d] cls loss: %.3f domain loss: %.3f acc: %.3f' %
                  (epoch + 1, i + 1, running_cls_loss / log_every,
                   running_domain_loss / log_every, running_acc / log_every))
            running_cls_loss = running_domain_loss = running_acc = 0.

print('Finished Training')

  return torch.max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode)


[1,    10] cls loss: 0.688 domain loss: 0.666 acc: 0.538
[1,    20] cls loss: 0.669 domain loss: 0.570 acc: 0.602
[1,    30] cls loss: 0.649 domain loss: 0.503 acc: 0.622
[1,    40] cls loss: 0.628 domain loss: 0.483 acc: 0.643
[1,    50] cls loss: 0.606 domain loss: 0.467 acc: 0.662
[1,    60] cls loss: 0.593 domain loss: 0.422 acc: 0.671
[1,    70] cls loss: 0.560 domain loss: 0.389 acc: 0.700
[1,    80] cls loss: 0.537 domain loss: 0.367 acc: 0.720
[1,    90] cls loss: 0.523 domain loss: 0.344 acc: 0.727
[1,   100] cls loss: 0.517 domain loss: 0.313 acc: 0.723
[1,   110] cls loss: 0.478 domain loss: 0.294 acc: 0.759
[2,    10] cls loss: 0.466 domain loss: 1.720 acc: 0.766
[2,    20] cls loss: 0.907 domain loss: 9.788 acc: 0.627
[2,    30] cls loss: 0.912 domain loss: 3.661 acc: 0.557
[2,    40] cls loss: 0.692 domain loss: 1.609 acc: 0.621
[2,    50] cls loss: 0.595 domain loss: 1.101 acc: 0.661
[2,    60] cls loss: 0.567 domain loss: 0.874 acc: 0.691
[2,    70] cls loss: 0.535 doma