# Domain adaptation

In [None]:
# Imported here because this cell runs before the "Dataset extraction" imports
# and DATASET_PATH below needs Path.
from pathlib import Path

# Downloaded Adaptiope archive and the directory it is unzipped into.
DATASET_ZIP_PATH = '/tmp/Adaptiope.zip'
DATASET_EXTRACTION_PATH = '/tmp/Adaptiope'
# Working copy that receives only the class subset selected during extraction.
DATASET_PATH = Path('./data/adaptiope_small')

## Dataset extraction

In [None]:
from os import makedirs
from os.path import join
from shutil import copytree
from pathlib import Path

In [None]:
# -o overwrites existing files without an interactive prompt, so the cell can be
# re-run; -q suppresses the per-file listing.
!mkdir -p {DATASET_EXTRACTION_PATH}
!unzip -q -o -d {DATASET_EXTRACTION_PATH} {DATASET_ZIP_PATH}

In [None]:
# Adaptiope categories kept for the experiments.
classes = ["backpack", "bookcase", "car jack", "comb", "crown", "file cabinet", "flat iron", "game controller", "glasses",
           "helicopter", "ice skates", "letter tray", "monitor", "mug", "network switch", "over-ear headphones", "pen",
           "purse", "stand mixer", "stroller"]

# Copy the kept class folders of both domains from the extracted archive into
# DATASET_PATH. exist_ok / dirs_exist_ok make the cell safe to re-run instead of
# failing because the target directories already exist.
for domain in ("product_images", "real_life"):
    source_root = join(DATASET_EXTRACTION_PATH, "Adaptiope", domain)
    target_root = join(DATASET_PATH, domain)
    makedirs(target_root, exist_ok=True)
    for c in classes:
        copytree(join(source_root, c), join(target_root, c), dirs_exist_ok=True)

## Dataset exploration

In [None]:
from torchvision.datasets import ImageFolder
import matplotlib.pyplot as plt

In [None]:
dataset = ImageFolder(DATASET_PATH / "product_images")
# Reverse of class_to_idx: numeric label -> class (folder) name, used for plot titles.
idx_to_class = {label: name for name, label in dataset.class_to_idx.items()}

In [None]:
# imgs: list[int] -- for each class, the dataset index of the first sample
# carrying that label, in the order the labels are first encountered.
seen_classes = set()
imgs = []
for i, (_, label) in enumerate(dataset.imgs):
    if label not in seen_classes:
        seen_classes.add(label)
        imgs.append(i)

In [None]:
fig, axs = plt.subplots(2, 5, figsize=(10, 10))
# Show one sample per class. Iterating instead of imgs.pop(0) leaves `imgs` intact,
# so the cell can be re-run. zip stops at whichever runs out first: the 10 axes or
# the available indices.
for ax, sample_idx in zip(axs.flat, imgs):
    image, label = dataset[sample_idx]
    ax.imshow(image)
    ax.set_title(idx_to_class[label])

## Utility functions

In [None]:
import random
import numpy as np
import torch


def set_random_seed(seed=0) -> None:
    """Seed the Python, NumPy and PyTorch (CPU and CUDA) RNGs with ``seed``.

    Also turns off cuDNN benchmark mode and requests deterministic cuDNN
    kernels so repeated runs behave the same.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed):
        seeder(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

def get_device() -> torch.device:
    """Return ``cuda:0`` when CUDA is available, otherwise the CPU device."""
    if torch.cuda.is_available():
        return torch.device("cuda:0")
    return torch.device("cpu")

## Model

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

from torchvision import models


class CustomModel(nn.Module):
    """ResNet-34 backbone with frozen weights and a trainable 3-layer MLP head.

    Args:
        num_classes: number of output logits.
        feature_dimension: width of the first hidden layer; the second hidden
            layer has ``feature_dimension // 2`` units.
        dropout_rate: dropout probability applied after each hidden ReLU.
    """

    def __init__(self, num_classes: int, feature_dimension: int, dropout_rate: float) -> None:
        super().__init__()
        # NOTE(review): newer torchvision releases deprecate `pretrained=` in favour
        # of `weights=`; confirm against the installed version.
        self.feature_extractor = models.resnet34(pretrained=True)
        self.feature_extractor.avgpool = nn.AdaptiveAvgPool2d(1)
        self.dropout_rate = dropout_rate

        out_feature_extractor = self.feature_extractor.fc.in_features

        # Freeze the backbone. The head assigned below is created afterwards,
        # so its parameters stay trainable.
        for param in self.feature_extractor.parameters():
            param.requires_grad = False

        self.feature_extractor.fc = nn.Sequential(
            nn.Linear(out_feature_extractor, feature_dimension),
            nn.ReLU(),
            nn.Dropout(dropout_rate),
            nn.Linear(feature_dimension, feature_dimension // 2),
            nn.ReLU(),
            nn.Dropout(dropout_rate),
            nn.Linear(feature_dimension // 2, num_classes),
        )

        # `fc` is an nn.Sequential, so walk its submodules to reach the Linear
        # layers; checking the container itself would never match.
        for m in self.feature_extractor.fc.modules():
            if isinstance(m, (nn.Conv2d, nn.Linear)):
                nn.init.kaiming_normal_(m.weight)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)

    def forward(self, x):
        """Return class logits for the image batch ``x``."""
        return self.feature_extractor(x)

## Dataset

In [None]:
from torchvision.datasets import ImageFolder
from torchvision import transforms as T

from torch.utils.data import DataLoader

In [None]:
# Deterministic preprocessing shared by every split: resize, 224x224 center crop,
# tensor conversion and per-channel normalisation.
preprocess = T.Compose(
    [
        T.Resize(224),
        T.CenterCrop(224),
        T.ToTensor(),
        T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ]
)

# Source domain: labelled product images used for training.
train = ImageFolder(DATASET_PATH / "product_images", transform=preprocess)
# NOTE(review): validation reads the same folder as `train`, so its loss is
# training-set loss, and model selection on it does not measure generalisation.
# Consider a held-out split of the source domain.
validation = ImageFolder(DATASET_PATH / "product_images", transform=preprocess)
# Target domain: the real-life photos of the same classes copied during
# extraction. Evaluating on them is the point of a domain-adaptation experiment.
test = ImageFolder(DATASET_PATH / "real_life", transform=preprocess)

## Train

In [None]:
from torch.optim import lr_scheduler, SGD, Adam

from tqdm import tqdm

In [None]:
# ---- misc ---------------------------------------------------------------
set_random_seed(33)
device = get_device()
num_threads = 16  # passed to DataLoader as num_workers

# ---- training -----------------------------------------------------------
num_epochs = 10
batch_size = 32
lr = 1e-4
weight_decay = 0
# StepLR: multiply the learning rate by scheduler_gamma every
# scheduler_step_size calls to scheduler.step().
scheduler_step_size = 5
scheduler_gamma = 0.2

In [None]:
def get_data_loader(dataset, batch_size, num_threads, shuffle=True):
    """Wrap ``dataset`` in a ``DataLoader``.

    Args:
        dataset: map-style dataset to batch.
        batch_size: number of samples per batch.
        num_threads: number of DataLoader worker processes (0 loads in the
            calling process).
        shuffle: reshuffle the samples each epoch. Defaults to True, as before;
            pass False for evaluation loaders to get a deterministic order.
    """
    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_threads,
    )

In [None]:
# The training loop moves every batch to `device`, so the model must live there too;
# otherwise the forward pass fails with a device mismatch on GPU.
model = CustomModel(len(train.classes), 128, 0.2).to(device)
# CustomModel freezes its backbone, so hand the optimizer only the trainable (head) parameters.
optimizer = Adam(
    (p for p in model.parameters() if p.requires_grad),
    lr=lr,
    weight_decay=weight_decay,
)
criterion = nn.CrossEntropyLoss()
scheduler = lr_scheduler.StepLR(
    optimizer,
    step_size=scheduler_step_size,
    gamma=scheduler_gamma,
)
dataloaders = {
    'train': get_data_loader(train, batch_size, num_threads),
    'validation': get_data_loader(validation, batch_size, num_threads),
    'test': get_data_loader(test, batch_size, num_threads)
}

In [None]:
import copy

# Loss scaling guards fp16 gradients against underflow under CUDA autocast; on CPU
# the scaler is disabled and scale()/step() reduce to plain backward()/optimizer.step().
scaler = torch.cuda.amp.GradScaler(enabled=device.type == "cuda")

# Keep a snapshot, not an alias: later optimizer steps must not modify the best model.
best_model = copy.deepcopy(model)
best_loss = float("inf")  # np.Inf was removed in NumPy 2.0
with tqdm(total=num_epochs) as pbar:
    for epoch in range(num_epochs):
        pbar.set_description(f"Epoch {epoch}")
        phases_loss = {}
        for phase, dataloader in dataloaders.items():
            is_train = phase == "train"
            model.train(is_train)  # model.train(False) is equivalent to model.eval()

            epoch_loss = 0.0
            for x, labels in dataloader:
                x = x.to(device)
                labels = labels.to(device)
                optimizer.zero_grad()

                with torch.set_grad_enabled(is_train):
                    # Only the forward pass and loss run under autocast;
                    # backward runs outside it.
                    with torch.autocast(device_type=device.type):
                        predictions = model(x)
                        loss = criterion(predictions, labels)
                    if is_train:
                        scaler.scale(loss).backward()
                        scaler.step(optimizer)
                        scaler.update()
                epoch_loss += loss.item()

            # Mean per-batch loss; max() avoids dividing by zero for an empty loader.
            epoch_loss /= max(len(dataloader), 1)
            phases_loss[phase] = epoch_loss

            if phase == "validation" and epoch_loss <= best_loss:
                best_model = copy.deepcopy(model)
                best_loss = epoch_loss

        # Step once per epoch: step_size counts scheduler.step() calls, so stepping
        # per batch would decay the learning rate every few batches.
        scheduler.step()
        pbar.set_postfix(phases_loss)
        pbar.update(1)