**Install Requirements**

In [0]:
"""!pip3 install 'torch==1.4.0'
!pip3 install 'torchvision==0.5.0'
!pip3 install 'Pillow-SIMD'
!pip3 install 'tqdm'"""

"!pip3 install 'torch==1.4.0'\n!pip3 install 'torchvision==0.5.0'\n!pip3 install 'Pillow-SIMD'\n!pip3 install 'tqdm'"

In [0]:
import os
"""if not os.path.isdir('./Pacs'):
  !git clone https://github.com/lore-lml/machine-learning2020-hw3.git
  !mv 'machine-learning2020-hw3' 'Pacs'
  !rm './Pacs/hw3.ipynb'
  !rm './Pacs/README.md'"""

import logging
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Subset, DataLoader
from torch.backends import cudnn

import torchvision
from torchvision import transforms
from pacs_dataset import Pacs
from dann import alexdann
from dann import train_src, test_target, dann_train_src_target

from PIL import Image
from tqdm import tqdm

from sklearn.model_selection import ParameterGrid
import matplotlib.pyplot as plt
%matplotlib inline

**Set Arguments**

In [0]:
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

NUM_CLASSES = 7

BATCH_SIZE = 128      # Higher batch sizes allows for larger learning rates. An empirical heuristic suggests that, when changing
                      # the batch size, learning rate should change by the same factor to have comparable results

LR = 1e-3             # The initial Learning Rate
MOMENTUM = 0.9        # Hyperparameter for SGD, keep this at 0.9 when using SGD
WEIGHT_DECAY = 5e-5   # Regularization, you can keep this at the default

NUM_EPOCHS = 10       # Total number of training epochs (iterations over dataset)
STEP_SIZE = 3         # How many epochs before decreasing learning rate (if using a step-down policy)
GAMMA = 0.5           # Multiplicative factor for learning rate step-down

ALPHA = 'dynamic'
BASE_FILE_PATH = "RUN_1_LR1e-3_DynAlpha_SGD_SS2_GAMMA03"

**Define Data Preprocessing**

In [0]:
transforms = transforms.Compose([transforms.Resize(256),      # Resizes short size of the PIL image to 256
                                      transforms.CenterCrop(224),  # Crops a central square patch of the image
                                                                   # 224 because torchvision's AlexNet needs a 224x224 input!
                                                                   # Remember this when applying different transformations, otherwise you get an error
                                      transforms.ToTensor(), # Turn PIL Image to torch.Tensor
                                      transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)) # Normalizes tensor with mean and standard deviation
])

**Prepare Dataset**

In [0]:
ROOT = 'Pacs/PACS'

source_data = Pacs(ROOT, transform=transforms, source='photo')
cartoon_data = Pacs(ROOT, transform=transforms, source='cartoon')
sketch_data = Pacs(ROOT, transform=transforms, source='sketch')
target_data = Pacs(ROOT, transform=transforms, source='art_painting')

_, source_labels = source_data.get_img_with_labels()
_, cartoon_labels = cartoon_data.get_img_with_labels()
_, sketch_labels = sketch_data.get_img_with_labels()
_, target_labels = target_data.get_img_with_labels()

print(f"# classes source_data: {len(set(source_labels))}")
print(f"# classes cartoon_data: {len(set(cartoon_labels))}")
print(f"# classes sketch_data: {len(set(sketch_labels))}")
print(f"# classes target_data: {len(set(target_labels))}")
print(f"source_data: {len(source_data)} elements")
print(f"cartoon_data: {len(cartoon_data)} elements")
print(f"sketch_data: {len(sketch_data)} elements")
print(f"target_data: {len(target_data)} elements")

# classes source_data: 7
# classes cartoon_data: 7
# classes sketch_data: 7
# classes target_data: 7
source_data: 1670 elements
cartoon_data: 2344 elements
sketch_data: 3929 elements
target_data: 2048 elements


**Prepare Dataloaders**

In [0]:
def get_data_loaders(dataset, batch_size=BATCH_SIZE):
  return DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=4, drop_last=True), DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=4, drop_last=False)
"""
# Photo Source Data Loader
source_dataloader = DataLoader(source_data, batch_size=BATCH_SIZE, shuffle=True, num_workers=4, drop_last=True)

# Cartoon and sketch validation Data Loaders
cartoon_dataloader = DataLoader(cartoon_data, batch_size=BATCH_SIZE, shuffle=True, num_workers=4, drop_last=True)
cartoon_test_dataloader = DataLoader(cartoon_data, batch_size=BATCH_SIZE, shuffle=False, num_workers=4, drop_last=False)
sketch_dataloader = DataLoader(sketch_data, batch_size=BATCH_SIZE, shuffle=True, num_workers=4, drop_last=True)
sketch_test_dataloader = DataLoader(sketch_data, batch_size=BATCH_SIZE, shuffle=False, num_workers=4, drop_last=False)

# Art_painting target Data Loader
target_dataloader = DataLoader(target_data, batch_size=BATCH_SIZE, shuffle=True, num_workers=4, drop_last=True)
test_dataloader = DataLoader(target_data, batch_size=BATCH_SIZE, shuffle=False, num_workers=4, drop_last=False)"""

'\n# Photo Source Data Loader\nsource_dataloader = DataLoader(source_data, batch_size=BATCH_SIZE, shuffle=True, num_workers=4, drop_last=True)\n\n# Cartoon and sketch validation Data Loaders\ncartoon_dataloader = DataLoader(cartoon_data, batch_size=BATCH_SIZE, shuffle=True, num_workers=4, drop_last=True)\ncartoon_test_dataloader = DataLoader(cartoon_data, batch_size=BATCH_SIZE, shuffle=False, num_workers=4, drop_last=False)\nsketch_dataloader = DataLoader(sketch_data, batch_size=BATCH_SIZE, shuffle=True, num_workers=4, drop_last=True)\nsketch_test_dataloader = DataLoader(sketch_data, batch_size=BATCH_SIZE, shuffle=False, num_workers=4, drop_last=False)\n\n# Art_painting target Data Loader\ntarget_dataloader = DataLoader(target_data, batch_size=BATCH_SIZE, shuffle=True, num_workers=4, drop_last=True)\ntest_dataloader = DataLoader(target_data, batch_size=BATCH_SIZE, shuffle=False, num_workers=4, drop_last=False)'

**Prepare Network**

In [0]:
def init_cnn_objects(model, lr=LR, step_size=STEP_SIZE, gamma=GAMMA):
  
  # Define loss function
  criterion_1 = nn.CrossEntropyLoss() # for classification, we use Cross Entropy
  criterion_2 = nn.CrossEntropyLoss()
  parameters_to_optimize = model.parameters() # In this case we optimize over all the parameters of AlexNet
  
  optimizer = optim.SGD(parameters_to_optimize, lr=lr, momentum=MOMENTUM, weight_decay=WEIGHT_DECAY)
  #optimizer = optim.Adam(parameters_to_optimize, lr=lr,amsgrad=True)
  #optimizer = optim.AdamW(parameters_to_optimize, lr=lr,amsgrad=True, weight_decay=WEIGHT_DECAY)
  scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=step_size, gamma=gamma)

  return criterion_1, criterion_2, optimizer, scheduler

**Training**

In [0]:
def simple_train_test(model, source_dataloader, test_dataloader, criterion, optimizer, scheduler, max_epoch=NUM_EPOCHS, file_path=BASE_FILE_PATH, device=DEVICE):
    train_losses = []
    loss_min = -1
    accuracy_max = 0
    
    model = model.to(device)
    cudnn.benchmark
    
    current_step = 0
    for epoch in range(max_epoch):
        print('Starting epoch {}/{}, LR = {}'.format(epoch+1, max_epoch, scheduler.get_lr()))
        cumulative_loss, current_step = train_src(model, source_dataloader, optimizer, criterion, current_step, device)
        curr_loss = cumulative_loss / len(source_dataloader)
        train_losses.append(curr_loss)
        """if loss_min == -1 or loss_min > curr_loss:
            loss_min = curr_loss
            torch.save(model, f"{file_path}_best_model.pth")"""
        curr_accuracy = test_target(model, test_dataloader, criterion, device) / float(len(target_data))
        print(f"Current Accuracy: {curr_accuracy}")
        if accuracy_max < curr_accuracy:
          accuracy_max = curr_accuracy
          torch.save(model, f"{file_path}_best_model.pth")
        scheduler.step()
        
    # model = torch.load(f"{file_path}_best_model.pth").to(device)
    #accuracy = test_target(model, test_dataloader, criterion, device) / float(len(target_data))
    #print(f"Accuracy on test set: {accuracy}%")
    return train_losses, accuracy_max

In [0]:
params = {
    "lr": [1e-3, 5e-3],
    "batch_size": [128],
    "epochs": [10],
    "step_size": [2, 3],
    "gamma": [0.5],
    "alpha": ['dynamic']
}

grid = ParameterGrid(params)
results = []
for config in grid:
  source_dataloader, _ = get_data_loaders(source_data, config['batch_size'])
  _, cartoon_test_dataloader = get_data_loaders(cartoon_data, config['batch_size'])
  _, sketch_test_dataloader = get_data_loaders(sketch_data, config['batch_size'])
  no_dann_cartoon = alexdann(pretrained=True)
  criterion_cartoon, _, optimizer_cartoon, scheduler_cartoon = init_cnn_objects(no_dann_cartoon, lr=config['lr'], 
                                                                          step_size=config['step_size'], gamma=config['gamma'])
  no_dann_sketch = alexdann(pretrained=True)
  criterion_sketch, _, optimizer_sketch, scheduler_sketch = init_cnn_objects(no_dann_sketch, lr=config['lr'], 
                                                                          step_size=config['step_size'], gamma=config['gamma'])

  cartoon_losses, cartoon_accuracy_max = simple_train_test(no_dann_cartoon, source_dataloader, cartoon_test_dataloader, 
                                                           criterion_cartoon, optimizer_cartoon, scheduler_cartoon, max_epoch=config['epochs'])
  sketch_losses, sketch_accuracy_max = simple_train_test(no_dann_sketch, source_dataloader, sketch_test_dataloader, 
                                                         criterion_sketch, optimizer_sketch, scheduler_sketch, max_epoch=config['epochs'])

  curr_result = {'params': config, 'cartoon_losses': cartoon_losses, 'sketch_losses': sketch_losses, 
                 'best_acc': np.mean([cartoon_accuracy_max, sketch_accuracy_max])}
  results.append(curr_result)
  
best_conf = max(results, key=lambda x: x['best_acc'])
print(f"Best configuration is:\n{best_conf['params']}")
print(f"Highest mean accuracy: {best_conf['best_acc']}")


Starting epoch 1/10, LR = [0.001]




Step 0, Loss_train 2.2247133255004883
Step 10, Loss_train 0.3557354509830475


100%|██████████| 19/19 [00:04<00:00,  3.97it/s]


Current Accuracy: 0.27099609375
Starting epoch 2/10, LR = [0.001]
Step 20, Loss_train 0.10233084112405777


100%|██████████| 19/19 [00:04<00:00,  3.91it/s]

Current Accuracy: 0.2255859375
Starting epoch 3/10, LR = [0.00025]





Step 30, Loss_train 0.08630967140197754


100%|██████████| 19/19 [00:04<00:00,  3.92it/s]


Current Accuracy: 0.2236328125
Starting epoch 4/10, LR = [0.0005]
Step 40, Loss_train 0.0869416892528534
Step 50, Loss_train 0.060862310230731964


100%|██████████| 19/19 [00:04<00:00,  4.03it/s]

Current Accuracy: 0.21533203125
Starting epoch 5/10, LR = [0.000125]





Step 60, Loss_train 0.06449291110038757


100%|██████████| 19/19 [00:04<00:00,  4.03it/s]

Current Accuracy: 0.2158203125
Starting epoch 6/10, LR = [0.00025]





Step 70, Loss_train 0.1312812715768814


100%|██████████| 19/19 [00:04<00:00,  4.01it/s]

Current Accuracy: 0.224609375
Starting epoch 7/10, LR = [6.25e-05]





Step 80, Loss_train 0.03078344836831093


  0%|          | 0/19 [00:00<?, ?it/s]

Step 90, Loss_train 0.0723256766796112


100%|██████████| 19/19 [00:04<00:00,  3.96it/s]

Current Accuracy: 0.2314453125
Starting epoch 8/10, LR = [0.000125]





Step 100, Loss_train 0.05629277601838112


100%|██████████| 19/19 [00:04<00:00,  3.97it/s]

Current Accuracy: 0.23291015625
Starting epoch 9/10, LR = [3.125e-05]





Step 110, Loss_train 0.10318563878536224


100%|██████████| 19/19 [00:04<00:00,  3.96it/s]

Current Accuracy: 0.23291015625
Starting epoch 10/10, LR = [6.25e-05]





Step 120, Loss_train 0.06643359363079071


100%|██████████| 19/19 [00:04<00:00,  4.03it/s]


Current Accuracy: 0.2314453125
Starting epoch 1/10, LR = [0.001]
Step 0, Loss_train 1.946158528327942
Step 10, Loss_train 0.32724887132644653


100%|██████████| 31/31 [00:07<00:00,  4.10it/s]


Current Accuracy: 0.33251953125
Starting epoch 2/10, LR = [0.001]
Step 20, Loss_train 0.13052600622177124


100%|██████████| 31/31 [00:07<00:00,  4.02it/s]


Current Accuracy: 0.33740234375
Starting epoch 3/10, LR = [0.00025]
Step 30, Loss_train 0.14166972041130066


100%|██████████| 31/31 [00:07<00:00,  4.06it/s]


Current Accuracy: 0.3525390625
Starting epoch 4/10, LR = [0.0005]
Step 40, Loss_train 0.15509797632694244
Step 50, Loss_train 0.038382451981306076


100%|██████████| 31/31 [00:07<00:00,  4.03it/s]


Current Accuracy: 0.359375
Starting epoch 5/10, LR = [0.000125]
Step 60, Loss_train 0.06435427069664001


100%|██████████| 31/31 [00:07<00:00,  4.08it/s]


Current Accuracy: 0.36376953125
Starting epoch 6/10, LR = [0.00025]
Step 70, Loss_train 0.043549779802560806


100%|██████████| 31/31 [00:07<00:00,  4.14it/s]


Current Accuracy: 0.3662109375
Starting epoch 7/10, LR = [6.25e-05]
Step 80, Loss_train 0.021374857053160667


  0%|          | 0/31 [00:00<?, ?it/s]

Step 90, Loss_train 0.0741807147860527


100%|██████████| 31/31 [00:07<00:00,  4.11it/s]

Current Accuracy: 0.3662109375
Starting epoch 8/10, LR = [0.000125]





Step 100, Loss_train 0.05111033096909523


100%|██████████| 31/31 [00:07<00:00,  4.07it/s]


Current Accuracy: 0.3681640625
Starting epoch 9/10, LR = [3.125e-05]
Step 110, Loss_train 0.03792368620634079


100%|██████████| 31/31 [00:07<00:00,  4.05it/s]

Current Accuracy: 0.3681640625
Starting epoch 10/10, LR = [6.25e-05]





Step 120, Loss_train 0.041632767766714096


100%|██████████| 31/31 [00:07<00:00,  4.16it/s]


Current Accuracy: 0.3681640625
Starting epoch 1/10, LR = [0.001]
Step 0, Loss_train 2.2543399333953857
Step 10, Loss_train 0.44556495547294617


100%|██████████| 19/19 [00:04<00:00,  4.11it/s]


Current Accuracy: 0.29345703125
Starting epoch 2/10, LR = [0.001]
Step 20, Loss_train 0.15508849918842316


100%|██████████| 19/19 [00:04<00:00,  4.09it/s]

Current Accuracy: 0.23779296875
Starting epoch 3/10, LR = [0.001]





Step 30, Loss_train 0.09122763574123383


100%|██████████| 19/19 [00:04<00:00,  4.10it/s]

Current Accuracy: 0.23876953125
Starting epoch 4/10, LR = [0.00025]





Step 40, Loss_train 0.08012250065803528
Step 50, Loss_train 0.03789059445261955


100%|██████████| 19/19 [00:04<00:00,  4.09it/s]

Current Accuracy: 0.26025390625
Starting epoch 5/10, LR = [0.0005]





Step 60, Loss_train 0.04703801870346069


100%|██████████| 19/19 [00:04<00:00,  4.11it/s]

Current Accuracy: 0.2685546875
Starting epoch 6/10, LR = [0.0005]





Step 70, Loss_train 0.0690833330154419


100%|██████████| 19/19 [00:04<00:00,  4.12it/s]

Current Accuracy: 0.271484375
Starting epoch 7/10, LR = [0.000125]





Step 80, Loss_train 0.057224173098802567


  0%|          | 0/19 [00:00<?, ?it/s]

Step 90, Loss_train 0.03435765206813812


100%|██████████| 19/19 [00:04<00:00,  4.05it/s]


Current Accuracy: 0.27197265625
Starting epoch 8/10, LR = [0.00025]
Step 100, Loss_train 0.02943149209022522


100%|██████████| 19/19 [00:04<00:00,  4.08it/s]

Current Accuracy: 0.271484375
Starting epoch 9/10, LR = [0.00025]





Step 110, Loss_train 0.03922044858336449


100%|██████████| 19/19 [00:04<00:00,  4.03it/s]

Current Accuracy: 0.27001953125
Starting epoch 10/10, LR = [6.25e-05]





Step 120, Loss_train 0.04062008857727051


100%|██████████| 19/19 [00:04<00:00,  4.12it/s]


Current Accuracy: 0.2685546875
Starting epoch 1/10, LR = [0.001]
Step 0, Loss_train 2.1901612281799316
Step 10, Loss_train 0.3927651643753052


100%|██████████| 31/31 [00:07<00:00,  4.16it/s]


Current Accuracy: 0.345703125
Starting epoch 2/10, LR = [0.001]
Step 20, Loss_train 0.2098926156759262


100%|██████████| 31/31 [00:07<00:00,  4.19it/s]


Current Accuracy: 0.3486328125
Starting epoch 3/10, LR = [0.001]
Step 30, Loss_train 0.05548528581857681


100%|██████████| 31/31 [00:07<00:00,  4.17it/s]

Current Accuracy: 0.33203125
Starting epoch 4/10, LR = [0.00025]





Step 40, Loss_train 0.05099150538444519
Step 50, Loss_train 0.05721869319677353


100%|██████████| 31/31 [00:07<00:00,  4.16it/s]


Current Accuracy: 0.3525390625
Starting epoch 5/10, LR = [0.0005]
Step 60, Loss_train 0.07337293028831482


100%|██████████| 31/31 [00:07<00:00,  4.12it/s]


Current Accuracy: 0.35546875
Starting epoch 6/10, LR = [0.0005]
Step 70, Loss_train 0.06036011874675751


100%|██████████| 31/31 [00:07<00:00,  4.15it/s]


Current Accuracy: 0.3583984375
Starting epoch 7/10, LR = [0.000125]
Step 80, Loss_train 0.0374753512442112


  0%|          | 0/31 [00:00<?, ?it/s]

Step 90, Loss_train 0.042490892112255096


100%|██████████| 31/31 [00:07<00:00,  4.14it/s]


Current Accuracy: 0.36181640625
Starting epoch 8/10, LR = [0.00025]
Step 100, Loss_train 0.066545769572258


100%|██████████| 31/31 [00:07<00:00,  4.05it/s]


Current Accuracy: 0.3662109375
Starting epoch 9/10, LR = [0.00025]
Step 110, Loss_train 0.019825201481580734


100%|██████████| 31/31 [00:07<00:00,  4.17it/s]

Current Accuracy: 0.36572265625
Starting epoch 10/10, LR = [6.25e-05]





Step 120, Loss_train 0.020981738343834877


100%|██████████| 31/31 [00:07<00:00,  4.22it/s]


Current Accuracy: 0.36572265625
Starting epoch 1/10, LR = [0.005]
Step 0, Loss_train 1.984547734260559
Step 10, Loss_train 0.2052433043718338


100%|██████████| 19/19 [00:04<00:00,  4.06it/s]


Current Accuracy: 0.2412109375
Starting epoch 2/10, LR = [0.005]
Step 20, Loss_train 0.11226536333560944


100%|██████████| 19/19 [00:04<00:00,  4.08it/s]


Current Accuracy: 0.32373046875
Starting epoch 3/10, LR = [0.00125]
Step 30, Loss_train 0.0426606610417366


100%|██████████| 19/19 [00:04<00:00,  4.13it/s]

Current Accuracy: 0.25048828125
Starting epoch 4/10, LR = [0.0025]





Step 40, Loss_train 0.023617219179868698
Step 50, Loss_train 0.02075241506099701


100%|██████████| 19/19 [00:04<00:00,  4.13it/s]

Current Accuracy: 0.3173828125
Starting epoch 5/10, LR = [0.000625]





Step 60, Loss_train 0.023557722568511963


100%|██████████| 19/19 [00:04<00:00,  4.12it/s]

Current Accuracy: 0.310546875
Starting epoch 6/10, LR = [0.00125]





Step 70, Loss_train 0.00479188933968544


100%|██████████| 19/19 [00:04<00:00,  4.11it/s]


Current Accuracy: 0.30859375
Starting epoch 7/10, LR = [0.0003125]
Step 80, Loss_train 0.012996301054954529


  0%|          | 0/19 [00:00<?, ?it/s]

Step 90, Loss_train 0.010664738714694977


100%|██████████| 19/19 [00:04<00:00,  3.82it/s]

Current Accuracy: 0.310546875
Starting epoch 8/10, LR = [0.000625]





Step 100, Loss_train 0.006366197019815445


100%|██████████| 19/19 [00:04<00:00,  4.09it/s]

Current Accuracy: 0.31982421875
Starting epoch 9/10, LR = [0.00015625]





Step 110, Loss_train 0.002278510481119156


100%|██████████| 19/19 [00:04<00:00,  4.11it/s]

Current Accuracy: 0.31884765625
Starting epoch 10/10, LR = [0.0003125]





Step 120, Loss_train 0.004923749715089798


100%|██████████| 19/19 [00:04<00:00,  4.06it/s]


Current Accuracy: 0.318359375
Starting epoch 1/10, LR = [0.005]
Step 0, Loss_train 2.005140781402588
Step 10, Loss_train 0.203957200050354


100%|██████████| 31/31 [00:07<00:00,  4.07it/s]


Current Accuracy: 0.53466796875
Starting epoch 2/10, LR = [0.005]
Step 20, Loss_train 0.29670149087905884


100%|██████████| 31/31 [00:07<00:00,  4.14it/s]


Current Accuracy: 0.669921875
Starting epoch 3/10, LR = [0.00125]
Step 30, Loss_train 0.06936131417751312


100%|██████████| 31/31 [00:07<00:00,  4.21it/s]

Current Accuracy: 0.56201171875
Starting epoch 4/10, LR = [0.0025]





Step 40, Loss_train 0.06711822748184204
Step 50, Loss_train 0.03940160199999809


100%|██████████| 31/31 [00:07<00:00,  4.18it/s]

Current Accuracy: 0.603515625
Starting epoch 5/10, LR = [0.000625]





Step 60, Loss_train 0.039227135479450226


100%|██████████| 31/31 [00:07<00:00,  4.15it/s]

Current Accuracy: 0.59130859375
Starting epoch 6/10, LR = [0.00125]





Step 70, Loss_train 0.011912494897842407


100%|██████████| 31/31 [00:07<00:00,  4.13it/s]

Current Accuracy: 0.5986328125
Starting epoch 7/10, LR = [0.0003125]





Step 80, Loss_train 0.010454406961798668


  0%|          | 0/31 [00:00<?, ?it/s]

Step 90, Loss_train 0.0236040111631155


100%|██████████| 31/31 [00:07<00:00,  4.14it/s]

Current Accuracy: 0.60302734375
Starting epoch 8/10, LR = [0.000625]





Step 100, Loss_train 0.0020936764776706696


100%|██████████| 31/31 [00:07<00:00,  4.12it/s]

Current Accuracy: 0.5966796875
Starting epoch 9/10, LR = [0.00015625]





Step 110, Loss_train 0.0058573149144649506


100%|██████████| 31/31 [00:07<00:00,  4.16it/s]

Current Accuracy: 0.6005859375
Starting epoch 10/10, LR = [0.0003125]





Step 120, Loss_train 0.0020259059965610504


100%|██████████| 31/31 [00:07<00:00,  4.22it/s]


Current Accuracy: 0.6025390625
Starting epoch 1/10, LR = [0.005]
Step 0, Loss_train 2.009279727935791
Step 10, Loss_train 0.21301713585853577


100%|██████████| 19/19 [00:04<00:00,  4.06it/s]


Current Accuracy: 0.24365234375
Starting epoch 2/10, LR = [0.005]
Step 20, Loss_train 0.09898912906646729


100%|██████████| 19/19 [00:04<00:00,  4.08it/s]


Current Accuracy: 0.275390625
Starting epoch 3/10, LR = [0.005]
Step 30, Loss_train 0.02189050428569317


100%|██████████| 19/19 [00:04<00:00,  4.09it/s]


Current Accuracy: 0.27880859375
Starting epoch 4/10, LR = [0.00125]
Step 40, Loss_train 0.030092205852270126
Step 50, Loss_train 0.014919957146048546


100%|██████████| 19/19 [00:04<00:00,  4.05it/s]


Current Accuracy: 0.28662109375
Starting epoch 5/10, LR = [0.0025]
Step 60, Loss_train 0.010206952691078186


100%|██████████| 19/19 [00:04<00:00,  4.13it/s]


Current Accuracy: 0.28759765625
Starting epoch 6/10, LR = [0.0025]
Step 70, Loss_train 0.010386999696493149


100%|██████████| 19/19 [00:04<00:00,  3.98it/s]

Current Accuracy: 0.2744140625
Starting epoch 7/10, LR = [0.000625]





Step 80, Loss_train 0.0055199600756168365


  0%|          | 0/19 [00:00<?, ?it/s]

Step 90, Loss_train 0.007059052586555481


100%|██████████| 19/19 [00:04<00:00,  4.05it/s]

Current Accuracy: 0.2822265625
Starting epoch 8/10, LR = [0.00125]





Step 100, Loss_train 0.016735605895519257


100%|██████████| 19/19 [00:04<00:00,  4.11it/s]

Current Accuracy: 0.28466796875
Starting epoch 9/10, LR = [0.00125]





Step 110, Loss_train 0.004445914179086685


100%|██████████| 19/19 [00:04<00:00,  4.10it/s]


Current Accuracy: 0.2900390625
Starting epoch 10/10, LR = [0.0003125]
Step 120, Loss_train 0.001730073243379593


100%|██████████| 19/19 [00:04<00:00,  4.05it/s]


Current Accuracy: 0.29541015625
Starting epoch 1/10, LR = [0.005]
Step 0, Loss_train 2.019496202468872
Step 10, Loss_train 0.3314792513847351


100%|██████████| 31/31 [00:07<00:00,  4.15it/s]


Current Accuracy: 0.44873046875
Starting epoch 2/10, LR = [0.005]
Step 20, Loss_train 0.13464894890785217


100%|██████████| 31/31 [00:07<00:00,  4.15it/s]


Current Accuracy: 0.48974609375
Starting epoch 3/10, LR = [0.005]
Step 30, Loss_train 0.025483388453722


100%|██████████| 31/31 [00:07<00:00,  4.18it/s]


Current Accuracy: 0.58984375
Starting epoch 4/10, LR = [0.00125]
Step 40, Loss_train 0.034829795360565186
Step 50, Loss_train 0.005049694329500198


100%|██████████| 31/31 [00:07<00:00,  4.08it/s]


Current Accuracy: 0.5986328125
Starting epoch 5/10, LR = [0.0025]
Step 60, Loss_train 0.014362942427396774


100%|██████████| 31/31 [00:07<00:00,  4.17it/s]

Current Accuracy: 0.59423828125
Starting epoch 6/10, LR = [0.0025]





Step 70, Loss_train 0.003117617219686508


100%|██████████| 31/31 [00:07<00:00,  4.17it/s]

Current Accuracy: 0.55078125
Starting epoch 7/10, LR = [0.000625]





Step 80, Loss_train 0.003047902137041092


  0%|          | 0/31 [00:00<?, ?it/s]

Step 90, Loss_train 0.0048539116978645325


100%|██████████| 31/31 [00:07<00:00,  4.14it/s]

Current Accuracy: 0.5966796875
Starting epoch 8/10, LR = [0.00125]





Step 100, Loss_train 0.004053492099046707


100%|██████████| 31/31 [00:07<00:00,  4.20it/s]


Current Accuracy: 0.60205078125
Starting epoch 9/10, LR = [0.00125]
Step 110, Loss_train 0.0007967203855514526


100%|██████████| 31/31 [00:07<00:00,  4.10it/s]

Current Accuracy: 0.60205078125
Starting epoch 10/10, LR = [0.0003125]





Step 120, Loss_train 0.002783454954624176


100%|██████████| 31/31 [00:07<00:00,  4.09it/s]

Current Accuracy: 0.599609375
Best configuration is:
{'alpha': 'dynamic', 'batch_size': 128, 'epochs': 10, 'gamma': 0.5, 'lr': 0.005, 'step_size': 2}
Highest mean accuracy: 0.496826171875





In [0]:
params = best_conf['params']

source_dataloader, _ = get_data_loaders(source_data, batch_size=params['batch_size'])
_, test_dataloader = get_data_loaders(target_data, batch_size=params['batch_size'])
no_dann = alexdann(pretrained=True)
criterion, _, optimizer, scheduler = init_cnn_objects(no_dann, lr=params['lr'], step_size=params['step_size'], gamma=params['gamma'])

losses_no_dann, accuracy_max = simple_train_test(no_dann, source_dataloader, test_dataloader, criterion, optimizer, scheduler, max_epoch=params['epochs'])
print(f"Photo transfer to Art accuracy no DANN: {accuracy_max}")

Starting epoch 1/10, LR = [0.005]




Step 0, Loss_train 2.5042428970336914
Step 10, Loss_train 0.10183731466531754


100%|██████████| 16/16 [00:04<00:00,  3.81it/s]


Current Accuracy: 0.4033203125
Starting epoch 2/10, LR = [0.005]
Step 20, Loss_train 0.0935046598315239


100%|██████████| 16/16 [00:04<00:00,  3.82it/s]


Current Accuracy: 0.5166015625
Starting epoch 3/10, LR = [0.00125]
Step 30, Loss_train 0.04062642902135849


100%|██████████| 16/16 [00:04<00:00,  3.89it/s]

Current Accuracy: 0.48583984375
Starting epoch 4/10, LR = [0.0025]





Step 40, Loss_train 0.004638992249965668
Step 50, Loss_train 0.02055942639708519


100%|██████████| 16/16 [00:04<00:00,  3.85it/s]

Current Accuracy: 0.47900390625
Starting epoch 5/10, LR = [0.000625]





Step 60, Loss_train 0.011482477188110352


100%|██████████| 16/16 [00:04<00:00,  3.89it/s]

Current Accuracy: 0.490234375
Starting epoch 6/10, LR = [0.00125]





Step 70, Loss_train 0.018621213734149933


100%|██████████| 16/16 [00:04<00:00,  3.87it/s]

Current Accuracy: 0.50146484375
Starting epoch 7/10, LR = [0.0003125]





Step 80, Loss_train 0.0031421221792697906


  0%|          | 0/16 [00:00<?, ?it/s]

Step 90, Loss_train 0.008878670632839203


100%|██████████| 16/16 [00:04<00:00,  3.87it/s]

Current Accuracy: 0.5087890625
Starting epoch 8/10, LR = [0.000625]





Step 100, Loss_train 0.01672378182411194


100%|██████████| 16/16 [00:04<00:00,  3.89it/s]

Current Accuracy: 0.50537109375
Starting epoch 9/10, LR = [0.00015625]





Step 110, Loss_train 0.00646287202835083


100%|██████████| 16/16 [00:04<00:00,  3.82it/s]

Current Accuracy: 0.501953125
Starting epoch 10/10, LR = [0.0003125]





Step 120, Loss_train 0.004503034055233002


100%|██████████| 16/16 [00:04<00:00,  3.86it/s]

Current Accuracy: 0.50048828125
Photo transfer to Art accuracy no DANN: 0.5166015625





**Training with DANN**

In [0]:
import math
def dann_train_test(model, source_dataloader, target_dataloader, test_dataloader, 
                    class_criterion, domain_criterion, optimizer, scheduler, max_epoch=NUM_EPOCHS, alpha=ALPHA, file_path=BASE_FILE_PATH, device=DEVICE):
    class_losses_y = []
    domain_losses_d = []
    accuracies = []
    loss_min = -1
    accuracy_max = 0
    count_diverge = 0
    
    model = model.to(device)
    cudnn.benchmark
    
    current_step = 0
    for epoch in range(max_epoch):
        print('Starting epoch {}/{}, LR = {}'.format(epoch+1, max_epoch, scheduler.get_lr()))
        class_loss, domain_loss, current_step = dann_train_src_target(model, source_dataloader, 
                                          target_dataloader, optimizer, class_criterion, domain_criterion, current_step,
                                          epoch, max_epoch, alpha=alpha, device=device)
        
        if math.isnan(class_loss) or math.isnan(domain_loss):
          count_diverge += 1
          if count_diverge >= 3:
            print("EARLY STOPPING")
            break;
        class_losses_y.append(class_loss)
        domain_losses_d.append(domain_loss)

        """if loss_min == -1 or loss_min > class_loss:
            loss_min = class_loss
            torch.save(model, f"{file_path}_best_model_dann.pth")"""
        accuracy = test_target(model, test_dataloader, class_criterion, device) / float(len(target_data))
        accuracies.append(accuracy)
        print(f"Accuracy on test set: {accuracy}")
        if accuracy_max < accuracy:
          accuracy_max = accuracy
          torch.save(model, f"{file_path}_best_model_dann.pth")
        scheduler.step()
        
    return class_losses_y, domain_losses_d, accuracies

In [0]:
params = {
    "lr": [1e-3, 5e-3],
    "batch_size": [128],
    "epochs": [10],
    "step_size": [2, 3],
    "gamma": [0.5],
    "alpha": ['dynamic']
}
grid = ParameterGrid(params)
dann_results = []
for config in grid:
  source_dataloader, _ = get_data_loaders(source_data, config['batch_size'])
  cartoon_dataloader, cartoon_test_dataloader = get_data_loaders(cartoon_data, config['batch_size'])
  sketch_dataloader, sketch_test_dataloader = get_data_loaders(sketch_data, config['batch_size'])
  dann_cartoon = alexdann(pretrained=True)
  criterion_cartoon_src, criterion_cartoon_tgt, optimizer_cartoon, scheduler_cartoon = init_cnn_objects(dann_cartoon, lr=config['lr'], 
                                                                          step_size=config['step_size'], gamma=config['gamma'])
  dann_sketch = alexdann(pretrained=True)
  criterion_sketch_src, criterion_sketch_tgt, optimizer_sketch, scheduler_sketch = init_cnn_objects(dann_sketch, lr=config['lr'], 
                                                                          step_size=config['step_size'], gamma=config['gamma'])

  class_losses_cartoon, domain_losses_cartoon, accuracies_cartoon = dann_train_test(
      dann_cartoon, source_dataloader, cartoon_dataloader, cartoon_test_dataloader,
      criterion_cartoon_src, criterion_cartoon_tgt, optimizer_cartoon, scheduler_cartoon, max_epoch=config['epochs'], alpha=config['alpha']
  )

  class_losses_sketch, domain_losses_sketch, accuracies_sketch = dann_train_test(
      dann_sketch, source_dataloader, sketch_dataloader, sketch_test_dataloader,
      criterion_sketch_src, criterion_sketch_tgt, optimizer_sketch, scheduler_sketch, max_epoch=config['epochs'], alpha=config['alpha']
  )

  curr_result = {'params': config, 'cartoon_losses': (class_losses_cartoon, domain_losses_cartoon), 'sketch_losses': (class_losses_sketch, domain_losses_sketch),
                 'best_acc': np.mean([max(accuracies_cartoon), max(accuracies_sketch)])}
  dann_results.append(curr_result)
  
best_conf = max(dann_results, key=lambda x: x['best_acc'])
print(f"Best configuration is:\n{best_conf['params']}")
print(f"Highest mean accuracy: {best_conf['best_acc']}")


Starting epoch 1/10, LR = [0.001]




Step 0
Class Loss 2.159154176712036, Domain Loss 3.7594082355499268
Step 10
Class Loss 0.45151984691619873, Domain Loss 0.2007991373538971


100%|██████████| 19/19 [00:05<00:00,  3.38it/s]


Accuracy on test set: 0.28857421875
Starting epoch 2/10, LR = [0.001]
Step 20
Class Loss 0.1243690550327301, Domain Loss 0.045085158199071884


100%|██████████| 19/19 [00:05<00:00,  3.38it/s]

Accuracy on test set: 0.28125
Starting epoch 3/10, LR = [0.00025]





Step 30
Class Loss 0.14589256048202515, Domain Loss 0.13656994700431824


100%|██████████| 19/19 [00:06<00:00,  2.94it/s]

Accuracy on test set: 0.24560546875
Starting epoch 4/10, LR = [0.0005]





Step 40
Class Loss 0.07491409033536911, Domain Loss 0.1646413803100586
Step 50
Class Loss 0.08851169794797897, Domain Loss 0.21914857625961304


100%|██████████| 19/19 [00:05<00:00,  3.35it/s]

Accuracy on test set: 0.2412109375
Starting epoch 5/10, LR = [0.000125]





Step 60
Class Loss 0.04667208716273308, Domain Loss 0.0847468227148056


100%|██████████| 19/19 [00:06<00:00,  3.03it/s]

Accuracy on test set: 0.25
Starting epoch 6/10, LR = [0.00025]





Step 70
Class Loss 0.07902085781097412, Domain Loss 0.07210561633110046


100%|██████████| 19/19 [00:06<00:00,  3.01it/s]

Accuracy on test set: 0.2587890625
Starting epoch 7/10, LR = [6.25e-05]





Step 80
Class Loss 0.10586325824260712, Domain Loss 0.054199181497097015
Step 90
Class Loss 0.10880999267101288, Domain Loss 0.086285799741745


100%|██████████| 19/19 [00:06<00:00,  3.01it/s]

Accuracy on test set: 0.26416015625
Starting epoch 8/10, LR = [0.000125]





Step 100
Class Loss 0.06568780541419983, Domain Loss 0.08492594957351685


100%|██████████| 19/19 [00:06<00:00,  3.03it/s]


Accuracy on test set: 0.259765625
Starting epoch 9/10, LR = [3.125e-05]
Step 110
Class Loss 0.04868962988257408, Domain Loss 0.1574188768863678


100%|██████████| 19/19 [00:06<00:00,  3.00it/s]

Accuracy on test set: 0.26220703125
Starting epoch 10/10, LR = [6.25e-05]





Step 120
Class Loss 0.03610925376415253, Domain Loss 0.03150957077741623


100%|██████████| 19/19 [00:06<00:00,  3.01it/s]


Accuracy on test set: 0.26416015625
Starting epoch 1/10, LR = [0.001]
Step 0
Class Loss 2.1323788166046143, Domain Loss 4.198472499847412
Step 10
Class Loss 0.44734469056129456, Domain Loss 0.032406169921159744


100%|██████████| 31/31 [00:08<00:00,  3.74it/s]


Accuracy on test set: 0.41748046875
Starting epoch 2/10, LR = [0.001]
Step 20
Class Loss 0.1534384787082672, Domain Loss 8.256733417510986e-05


100%|██████████| 31/31 [00:08<00:00,  3.80it/s]

Accuracy on test set: 0.41650390625
Starting epoch 3/10, LR = [0.00025]





Step 30
Class Loss 0.14491023123264313, Domain Loss 0.00045587122440338135


100%|██████████| 31/31 [00:08<00:00,  3.74it/s]


Accuracy on test set: 0.43408203125
Starting epoch 4/10, LR = [0.0005]
Step 40
Class Loss 0.09398488700389862, Domain Loss 2.6695430278778076e-05
Step 50
Class Loss 0.07851475477218628, Domain Loss 0.008645139634609222


100%|██████████| 31/31 [00:08<00:00,  3.76it/s]

Accuracy on test set: 0.4306640625
Starting epoch 5/10, LR = [0.000125]





Step 60
Class Loss 0.05833827704191208, Domain Loss 0.0068552494049072266


100%|██████████| 31/31 [00:08<00:00,  3.85it/s]

Accuracy on test set: 0.42333984375
Starting epoch 6/10, LR = [0.00025]





Step 70
Class Loss 0.10419528931379318, Domain Loss 0.00022746622562408447


100%|██████████| 31/31 [00:08<00:00,  3.80it/s]

Accuracy on test set: 0.42138671875
Starting epoch 7/10, LR = [6.25e-05]





Step 80
Class Loss 0.06938070803880692, Domain Loss 8.32974910736084e-05
Step 90
Class Loss 0.015189025551080704, Domain Loss 1.9572675228118896e-05


100%|██████████| 31/31 [00:08<00:00,  3.79it/s]

Accuracy on test set: 0.419921875
Starting epoch 8/10, LR = [0.000125]





Step 100
Class Loss 0.0651884451508522, Domain Loss 2.2649765014648438e-05


100%|██████████| 31/31 [00:08<00:00,  3.75it/s]

Accuracy on test set: 0.42236328125
Starting epoch 9/10, LR = [3.125e-05]





Step 110
Class Loss 0.06314189732074738, Domain Loss 2.60770320892334e-07


100%|██████████| 31/31 [00:08<00:00,  3.81it/s]

Accuracy on test set: 0.4208984375
Starting epoch 10/10, LR = [6.25e-05]





Step 120
Class Loss 0.06457735598087311, Domain Loss 8.27014446258545e-06


100%|██████████| 31/31 [00:08<00:00,  3.79it/s]


Accuracy on test set: 0.41845703125
Starting epoch 1/10, LR = [0.001]
Step 0
Class Loss 2.4016268253326416, Domain Loss 3.7950029373168945
Step 10
Class Loss 0.39654719829559326, Domain Loss 0.15932771563529968


100%|██████████| 19/19 [00:05<00:00,  3.70it/s]


Accuracy on test set: 0.3056640625
Starting epoch 2/10, LR = [0.001]
Step 20
Class Loss 0.13102968037128448, Domain Loss 0.08987453579902649


100%|██████████| 19/19 [00:05<00:00,  3.70it/s]

Accuracy on test set: 0.28125
Starting epoch 3/10, LR = [0.001]





Step 30
Class Loss 0.06773464381694794, Domain Loss 0.1265941709280014


100%|██████████| 19/19 [00:05<00:00,  3.65it/s]

Accuracy on test set: 0.2421875
Starting epoch 4/10, LR = [0.00025]





Step 40
Class Loss 0.035446736961603165, Domain Loss 0.08078469336032867
Step 50
Class Loss 0.054708078503608704, Domain Loss 0.09657035022974014


100%|██████████| 19/19 [00:05<00:00,  3.68it/s]

Accuracy on test set: 0.248046875
Starting epoch 5/10, LR = [0.0005]





Step 60
Class Loss 0.04839436709880829, Domain Loss 0.08114340156316757


100%|██████████| 19/19 [00:05<00:00,  3.70it/s]

Accuracy on test set: 0.29736328125
Starting epoch 6/10, LR = [0.0005]





Step 70
Class Loss 0.08176718652248383, Domain Loss 0.055620498955249786


100%|██████████| 19/19 [00:05<00:00,  3.68it/s]

Accuracy on test set: 0.29345703125
Starting epoch 7/10, LR = [0.000125]





Step 80
Class Loss 0.0334208607673645, Domain Loss 0.08032817393541336
Step 90
Class Loss 0.058445531874895096, Domain Loss 0.08037639409303665


100%|██████████| 19/19 [00:05<00:00,  3.65it/s]

Accuracy on test set: 0.29150390625
Starting epoch 8/10, LR = [0.00025]





Step 100
Class Loss 0.03412403538823128, Domain Loss 0.12461169064044952


100%|██████████| 19/19 [00:05<00:00,  3.64it/s]

Accuracy on test set: 0.279296875
Starting epoch 9/10, LR = [0.00025]





Step 110
Class Loss 0.03582194074988365, Domain Loss 0.024164550006389618


100%|██████████| 19/19 [00:05<00:00,  3.60it/s]

Accuracy on test set: 0.27685546875
Starting epoch 10/10, LR = [6.25e-05]





Step 120
Class Loss 0.0424511581659317, Domain Loss 0.06794572621583939


100%|██████████| 19/19 [00:05<00:00,  3.61it/s]

Accuracy on test set: 0.27685546875
Starting epoch 1/10, LR = [0.001]





Step 0
Class Loss 2.1758668422698975, Domain Loss 4.690292835235596
Step 10
Class Loss 0.32351213693618774, Domain Loss 0.00022855401039123535


100%|██████████| 31/31 [00:08<00:00,  3.47it/s]


Accuracy on test set: 0.53759765625
Starting epoch 2/10, LR = [0.001]
Step 20
Class Loss 0.10607564449310303, Domain Loss 5.2675604820251465e-06


100%|██████████| 31/31 [00:08<00:00,  3.59it/s]

Accuracy on test set: 0.52587890625
Starting epoch 3/10, LR = [0.001]





Step 30
Class Loss 0.08507214486598969, Domain Loss 0.00044621527194976807


100%|██████████| 31/31 [00:08<00:00,  3.54it/s]

Accuracy on test set: 0.47021484375
Starting epoch 4/10, LR = [0.00025]





Step 40
Class Loss 0.07326200604438782, Domain Loss 5.02467155456543e-05
Step 50
Class Loss 0.08197817206382751, Domain Loss 8.195638656616211e-08


100%|██████████| 31/31 [00:08<00:00,  3.61it/s]

Accuracy on test set: 0.47021484375
Starting epoch 5/10, LR = [0.0005]





Step 60
Class Loss 0.055949073284864426, Domain Loss 1.4901161193847656e-08


100%|██████████| 31/31 [00:08<00:00,  3.59it/s]

Accuracy on test set: 0.47607421875
Starting epoch 6/10, LR = [0.0005]





Step 70
Class Loss 0.04477832838892937, Domain Loss 2.0012259483337402e-05


100%|██████████| 31/31 [00:08<00:00,  3.57it/s]

Accuracy on test set: 0.51953125
Starting epoch 7/10, LR = [0.000125]





Step 80
Class Loss 0.05425666645169258, Domain Loss 0.03893418610095978
Step 90
Class Loss 0.04717155545949936, Domain Loss 7.286667823791504e-06


100%|██████████| 31/31 [00:08<00:00,  3.56it/s]

Accuracy on test set: 0.517578125
Starting epoch 8/10, LR = [0.00025]





Step 100
Class Loss 0.049668096005916595, Domain Loss 0.011038452386856079


100%|██████████| 31/31 [00:08<00:00,  3.53it/s]

Accuracy on test set: 0.52197265625
Starting epoch 9/10, LR = [0.00025]





Step 110
Class Loss 0.05148801580071449, Domain Loss 3.993511199951172e-06


100%|██████████| 31/31 [00:08<00:00,  3.48it/s]

Accuracy on test set: 0.5048828125
Starting epoch 10/10, LR = [6.25e-05]





Step 120
Class Loss 0.02212756872177124, Domain Loss 1.817941665649414e-06


100%|██████████| 31/31 [00:08<00:00,  3.54it/s]


Accuracy on test set: 0.5078125
Starting epoch 1/10, LR = [0.005]
Step 0
Class Loss 2.131915330886841, Domain Loss 3.82125186920166
Step 10
Class Loss 0.20021958649158478, Domain Loss 0.021517805755138397


100%|██████████| 19/19 [00:05<00:00,  3.55it/s]


Accuracy on test set: 0.24755859375
Starting epoch 2/10, LR = [0.005]
Step 20
Class Loss 0.10147754102945328, Domain Loss 0.4521730840206146


100%|██████████| 19/19 [00:05<00:00,  3.50it/s]


Accuracy on test set: 0.5126953125
Starting epoch 3/10, LR = [0.00125]
Step 30
Class Loss 0.07465137541294098, Domain Loss 0.6385120153427124


100%|██████████| 19/19 [00:05<00:00,  3.54it/s]

Accuracy on test set: 0.23583984375
Starting epoch 4/10, LR = [0.0025]





Step 40
Class Loss 0.061328742653131485, Domain Loss 0.2681695222854614
Step 50
Class Loss 0.04178217798471451, Domain Loss 0.6906050443649292


100%|██████████| 19/19 [00:05<00:00,  3.53it/s]

Accuracy on test set: 0.49658203125
Starting epoch 5/10, LR = [0.000625]





Step 60
Class Loss 0.16796454787254333, Domain Loss 3.2730884552001953


100%|██████████| 19/19 [00:05<00:00,  3.51it/s]

Accuracy on test set: 0.30419921875
Starting epoch 6/10, LR = [0.00125]





Step 70
Class Loss 5.275023937225342, Domain Loss 15.920379638671875


100%|██████████| 19/19 [00:05<00:00,  3.52it/s]

Accuracy on test set: 0.14013671875
Starting epoch 7/10, LR = [0.0003125]





Step 80
Class Loss 86.57444763183594, Domain Loss 14.528834342956543
Step 90
Class Loss 2.3437533378601074, Domain Loss 2.0586156845092773


100%|██████████| 19/19 [00:05<00:00,  3.51it/s]

Accuracy on test set: 0.140625
Starting epoch 8/10, LR = [0.000625]





Step 100
Class Loss 1.8150150812116976e+19, Domain Loss 2.3052006068115633e+30


100%|██████████| 19/19 [00:05<00:00,  3.49it/s]

Accuracy on test set: 0.18994140625
Starting epoch 9/10, LR = [0.00015625]





Step 110
Class Loss nan, Domain Loss nan


100%|██████████| 19/19 [00:05<00:00,  3.50it/s]

Accuracy on test set: 0.18994140625
Starting epoch 10/10, LR = [0.0003125]





Step 120
Class Loss nan, Domain Loss nan
EARLY STOPPING
Starting epoch 1/10, LR = [0.005]
Step 0
Class Loss 2.0286717414855957, Domain Loss 5.408343315124512
Step 10
Class Loss 0.2839500308036804, Domain Loss 1.043081283569336e-07


100%|██████████| 31/31 [00:09<00:00,  3.20it/s]


Accuracy on test set: 0.45703125
Starting epoch 2/10, LR = [0.005]
Step 20
Class Loss 0.3123788833618164, Domain Loss 0.010232031345367432


100%|██████████| 31/31 [00:09<00:00,  3.13it/s]


Accuracy on test set: 0.72119140625
Starting epoch 3/10, LR = [0.00125]
Step 30
Class Loss 0.016688965260982513, Domain Loss 0.0


100%|██████████| 31/31 [00:09<00:00,  3.23it/s]

Accuracy on test set: 0.61181640625
Starting epoch 4/10, LR = [0.0025]





Step 40
Class Loss 0.03440393880009651, Domain Loss 0.0
Step 50
Class Loss 0.01783856749534607, Domain Loss 0.0


100%|██████████| 31/31 [00:09<00:00,  3.12it/s]

Accuracy on test set: 0.5849609375
Starting epoch 5/10, LR = [0.000625]





Step 60
Class Loss 0.02821243181824684, Domain Loss 3.1888484954833984e-06


100%|██████████| 31/31 [00:09<00:00,  3.20it/s]

Accuracy on test set: 0.6005859375
Starting epoch 6/10, LR = [0.00125]





Step 70
Class Loss 0.012601740658283234, Domain Loss 0.0


100%|██████████| 31/31 [00:09<00:00,  3.18it/s]

Accuracy on test set: 0.6044921875
Starting epoch 7/10, LR = [0.0003125]





Step 80
Class Loss 0.0048980191349983215, Domain Loss 0.0
Step 90
Class Loss 0.009098511189222336, Domain Loss 0.0


100%|██████████| 31/31 [00:09<00:00,  3.17it/s]

Accuracy on test set: 0.603515625
Starting epoch 8/10, LR = [0.000625]





Step 100
Class Loss 0.005657710134983063, Domain Loss 0.0


100%|██████████| 31/31 [00:09<00:00,  3.15it/s]

Accuracy on test set: 0.60546875
Starting epoch 9/10, LR = [0.00015625]





Step 110
Class Loss 0.008107181638479233, Domain Loss 0.0


100%|██████████| 31/31 [00:09<00:00,  3.13it/s]

Accuracy on test set: 0.607421875
Starting epoch 10/10, LR = [0.0003125]





Step 120
Class Loss 0.008991900831460953, Domain Loss 0.0


100%|██████████| 31/31 [00:09<00:00,  3.12it/s]


Accuracy on test set: 0.60400390625
Starting epoch 1/10, LR = [0.005]
Step 0
Class Loss 2.124458074569702, Domain Loss 4.689288139343262
Step 10
Class Loss 0.2415754497051239, Domain Loss 0.3377162218093872


100%|██████████| 19/19 [00:05<00:00,  3.18it/s]


Accuracy on test set: 0.1728515625
Starting epoch 2/10, LR = [0.005]
Step 20
Class Loss 0.2181115746498108, Domain Loss 0.11206312477588654


100%|██████████| 19/19 [00:05<00:00,  3.18it/s]


Accuracy on test set: 0.37255859375
Starting epoch 3/10, LR = [0.005]
Step 30
Class Loss 0.11189925670623779, Domain Loss 0.15085381269454956


100%|██████████| 19/19 [00:05<00:00,  3.39it/s]

Accuracy on test set: 0.28759765625
Starting epoch 4/10, LR = [0.00125]





Step 40
Class Loss 0.04479313641786575, Domain Loss 0.2484806776046753
Step 50
Class Loss 0.034663908183574677, Domain Loss 0.36681029200553894


100%|██████████| 19/19 [00:05<00:00,  3.39it/s]

Accuracy on test set: 0.3095703125
Starting epoch 5/10, LR = [0.0025]





Step 60
Class Loss 0.24391672015190125, Domain Loss 1.9459072351455688


100%|██████████| 19/19 [00:05<00:00,  3.32it/s]

Accuracy on test set: 0.32177734375
Starting epoch 6/10, LR = [0.0025]





Step 70
Class Loss 11.85999870300293, Domain Loss 35.89902877807617


100%|██████████| 19/19 [00:05<00:00,  3.44it/s]

Accuracy on test set: 0.18994140625
Starting epoch 7/10, LR = [0.000625]





Step 80
Class Loss nan, Domain Loss nan
Step 90
Class Loss nan, Domain Loss nan


100%|██████████| 19/19 [00:05<00:00,  3.44it/s]

Accuracy on test set: 0.18994140625
Starting epoch 8/10, LR = [0.00125]





Step 100
Class Loss nan, Domain Loss nan
EARLY STOPPING
Starting epoch 1/10, LR = [0.005]
Step 0
Class Loss 2.103686809539795, Domain Loss 4.341101169586182
Step 10
Class Loss 0.12551315128803253, Domain Loss 4.0978193283081055e-06


100%|██████████| 31/31 [00:10<00:00,  2.99it/s]


Accuracy on test set: 0.36865234375
Starting epoch 2/10, LR = [0.005]
Step 20
Class Loss 0.17424973845481873, Domain Loss 0.0


100%|██████████| 31/31 [00:09<00:00,  3.14it/s]


Accuracy on test set: 0.37109375
Starting epoch 3/10, LR = [0.005]
Step 30
Class Loss 0.04116667062044144, Domain Loss 0.0


100%|██████████| 31/31 [00:09<00:00,  3.16it/s]


Accuracy on test set: 0.494140625
Starting epoch 4/10, LR = [0.00125]
Step 40
Class Loss 0.02938031405210495, Domain Loss 0.0
Step 50
Class Loss 0.03312504291534424, Domain Loss 5.960464477539063e-08


100%|██████████| 31/31 [00:09<00:00,  3.14it/s]


Accuracy on test set: 0.53466796875
Starting epoch 5/10, LR = [0.0025]
Step 60
Class Loss 0.016442783176898956, Domain Loss 0.0


100%|██████████| 31/31 [00:09<00:00,  3.15it/s]

Accuracy on test set: 0.5263671875
Starting epoch 6/10, LR = [0.0025]





Step 70
Class Loss 0.003584064543247223, Domain Loss 0.0


100%|██████████| 31/31 [00:09<00:00,  3.15it/s]


Accuracy on test set: 0.56005859375
Starting epoch 7/10, LR = [0.000625]
Step 80
Class Loss 0.010727953165769577, Domain Loss 0.08107149600982666
Step 90
Class Loss 0.009785372763872147, Domain Loss 5.047023296356201e-05


100%|██████████| 31/31 [00:09<00:00,  3.14it/s]

Accuracy on test set: 0.55224609375
Starting epoch 8/10, LR = [0.00125]





Step 100
Class Loss 0.011454183608293533, Domain Loss 0.004133790731430054


100%|██████████| 31/31 [00:09<00:00,  3.13it/s]

Accuracy on test set: 0.52099609375
Starting epoch 9/10, LR = [0.00125]





Step 110
Class Loss 3.8198084831237793, Domain Loss 63.0916862487793


100%|██████████| 31/31 [00:09<00:00,  3.10it/s]

Accuracy on test set: 0.376953125
Starting epoch 10/10, LR = [0.0003125]





Step 120
Class Loss nan, Domain Loss nan


100%|██████████| 31/31 [00:09<00:00,  3.11it/s]

Accuracy on test set: 0.376953125
Best configuration is:
{'alpha': 'dynamic', 'batch_size': 128, 'epochs': 10, 'gamma': 0.5, 'lr': 0.005, 'step_size': 2}
Highest mean accuracy: 0.616943359375





In [0]:
params = best_conf['params']

source_dataloader, _ = get_data_loaders(source_data, params['batch_size'])
target_dataloader, test_dataloader = get_data_loaders(target_data, params['batch_size'])

dann = alexdann(pretrained=True)
class_criterion, domain_criterion, optimizer, scheduler = init_cnn_objects(dann, lr=params['lr'], step_size=params['step_size'],
                                                                           gamma=params['gamma'])
class_losses_y, domain_losses_d, accuracies = dann_train_test(dann, source_dataloader, target_dataloader,
                          test_dataloader, class_criterion, domain_criterion, optimizer, scheduler, max_epoch=params['epochs'], alpha=params['alpha'])

print(f"Photo transfer to Art accuracy DANN: {max(accuracies)}")

Starting epoch 1/10, LR = [0.005]




Step 0
Class Loss 2.2315292358398438, Domain Loss 4.035977840423584
Step 10
Class Loss 0.13638703525066376, Domain Loss 0.2569566071033478


100%|██████████| 16/16 [00:06<00:00,  2.67it/s]


Accuracy on test set: 0.47216796875
Starting epoch 2/10, LR = [0.005]
Step 20
Class Loss 0.10701119899749756, Domain Loss 0.753000020980835


100%|██████████| 16/16 [00:05<00:00,  2.76it/s]


Accuracy on test set: 0.4775390625
Starting epoch 3/10, LR = [0.00125]
Step 30
Class Loss 0.09845760464668274, Domain Loss 0.271810382604599


100%|██████████| 16/16 [00:05<00:00,  2.98it/s]


Accuracy on test set: 0.5009765625
Starting epoch 4/10, LR = [0.0025]
Step 40
Class Loss 0.0423257052898407, Domain Loss 0.26805007457733154
Step 50
Class Loss 0.028122903779149055, Domain Loss 0.22108566761016846


100%|██████████| 16/16 [00:05<00:00,  3.06it/s]


Accuracy on test set: 0.5322265625
Starting epoch 5/10, LR = [0.000625]
Step 60
Class Loss 0.03023330122232437, Domain Loss 0.25207847356796265


100%|██████████| 16/16 [00:05<00:00,  3.08it/s]

Accuracy on test set: 0.513671875
Starting epoch 6/10, LR = [0.00125]





Step 70
Class Loss 0.02023480460047722, Domain Loss 0.24564605951309204


100%|██████████| 16/16 [00:05<00:00,  3.11it/s]

Accuracy on test set: 0.49609375
Starting epoch 7/10, LR = [0.0003125]





Step 80
Class Loss 0.03136119991540909, Domain Loss 0.324157178401947
Step 90
Class Loss 0.025508960708975792, Domain Loss 0.40186041593551636


100%|██████████| 16/16 [00:05<00:00,  2.99it/s]

Accuracy on test set: 0.48779296875
Starting epoch 8/10, LR = [0.000625]





Step 100
Class Loss 0.048933450132608414, Domain Loss 0.5279905796051025


100%|██████████| 16/16 [00:05<00:00,  3.08it/s]

Accuracy on test set: 0.48095703125
Starting epoch 9/10, LR = [0.00015625]





Step 110
Class Loss 0.04508736729621887, Domain Loss 0.44411158561706543


100%|██████████| 16/16 [00:05<00:00,  3.00it/s]

Accuracy on test set: 0.44970703125
Starting epoch 10/10, LR = [0.0003125]





Step 120
Class Loss 0.04244048893451691, Domain Loss 0.4897080063819885


100%|██████████| 16/16 [00:05<00:00,  3.06it/s]

Accuracy on test set: 0.4462890625
Photo transfer to Art accuracy DANN: 0.5322265625



