In [None]:
# PACS details
# The four visual domains; each run holds one out as the test domain.
domains = [
    'photo',
    'art_painting',
    'cartoon',
    'sketch',
]
# The seven object categories, listed alphabetically (the order torchvision's
# ImageFolder uses when it assigns label indices from folder names).
classes = [
    'dog',
    'elephant',
    'giraffe',
    'guitar',
    'horse',
    'house',
    'person',
]

# Set parameters

In [None]:
import torch
import torchvision.transforms as transforms

from torchvision import models
from torchvision.datasets import ImageFolder, DatasetFolder
import torchvision.datasets as Datasets
from torch.utils.data import DataLoader
from torch import nn, optim

#Modify
from utils import *
import os


import numpy as np
import matplotlib.pyplot as plt



##############################
# Training Setting
##############################
used_model = 'resnet18'            # backbone; setting() handles 'resnet18', 'vgg16', 'inceptionv3'
dataset = 'pacs'
save_name = 'IDCL'                 # experiment name passed to the save helpers
pacs_ver = 'pacs_official_split'   # dataset root containing train/, val/ and test/ folders
number_of_tests = 20               # how many times the full leave-one-domain-out sweep is repeated

##############################
# Hyper-parameters
##############################

is_pretrained = True   # forwarded as `pretrained=` to the torchvision model constructor
color_jitter = True    # forwarded to get_tf

epochs = 30
batch_size = 128
lr = 4e-3
lr_decay_epoch = [24]  # MultiStepLR milestones, in epochs
lr_decay_gamma = 0.1   # LR multiplier applied at each milestone
gpu_num = 0

# Seed the CPU generator first, then the CUDA one.
torch.manual_seed(0)
torch.cuda.manual_seed(0)

# Use the selected GPU when CUDA is present, otherwise fall back to the CPU.
use_gpu = torch.cuda.is_available()
if use_gpu:
    print("Using CUDA")
device = torch.device("cuda:{}".format(gpu_num)) if use_gpu else torch.device('cpu')
print(device)

# Loss shared by every run in the sweep.
criterion = nn.CrossEntropyLoss().to(device)


# Train/test transform pair from utils.get_tf (augment=True; the exact
# augmentations live in utils -- not visible here).
train_tf, test_tf = get_tf(color_jitter, augment=True)

# Snapshot of the run configuration, persisted via save_model_setting.
# Keyword form keeps the same keys in the same insertion order.
model_settings = dict(
    used_model=used_model,
    dataset=dataset,
    save_name=save_name,
    pacs_ver=pacs_ver,
    number_of_tests=number_of_tests,
    epochs=epochs,
    batch_size=batch_size,
    is_pretrained=is_pretrained,
    color_jitter=color_jitter,
    lr=lr,
    lr_decay_epoch=lr_decay_epoch,
    lr_decay_gamma=lr_decay_gamma,
    gpu_num=gpu_num,
)

# Functions

In [None]:
def _build_model(used_model, is_pretrained, num_classes):
    """Create a torchvision backbone whose classification head(s) output `num_classes` logits.

    Assigning `.out_features` on an existing nn.Linear only changes an attribute;
    its weight keeps the original (1000, in_features) shape. So every head is
    replaced with a freshly initialised nn.Linear of the right size.

    Raises:
        NotImplementedError: if `used_model` is not one of the supported names.
    """
    if used_model == 'vgg16':
        print('vgg16')
        model = models.vgg16(pretrained=is_pretrained)
        model.classifier[6] = nn.Linear(model.classifier[6].in_features, num_classes)
    elif used_model == 'inceptionv3':
        model = models.inception_v3(pretrained=is_pretrained)
        # Inception has an auxiliary classifier as well as the main one; resize both.
        model.AuxLogits.fc = nn.Linear(model.AuxLogits.fc.in_features, num_classes)
        model.fc = nn.Linear(model.fc.in_features, num_classes)
    elif used_model == 'resnet18':
        model = models.resnet18(pretrained=is_pretrained)
        model.fc = nn.Linear(model.fc.in_features, num_classes)
    else:
        raise NotImplementedError
    return model


def setting(test_domain_idx, domains, batch_size, is_pretrained, train_tf, test_tf, used_model, pacs_ver,
            num_classes=7, num_workers=6):
    """Build the staged data loaders, model, optimizer and LR scheduler for one held-out domain.

    The domain at `test_domain_idx` is the test domain. The remaining (source) domains
    are taken in list order, and the first three form cumulative training stages:
    stage1 = 1st source domain, stage2 = 1st + 2nd, stage3 = 1st + 2nd + 3rd.
    Data is read from `<pacs_ver>/train/<domain>`, `<pacs_ver>/val/<domain>` and
    `<pacs_ver>/test/<domain>` with torchvision's ImageFolder.

    Args:
        test_domain_idx: index into `domains` of the held-out test domain.
        domains: domain folder names.
        batch_size: batch size used by every DataLoader.
        is_pretrained: forwarded as `pretrained=` to the torchvision constructor.
        train_tf: transform for the training split.
        test_tf: transform for the val and test splits.
        used_model: 'resnet18', 'vgg16' or 'inceptionv3'.
        pacs_ver: dataset root directory.
        num_classes: number of output logits (7 for PACS).
        num_workers: worker processes per DataLoader.

    Returns:
        (train_loader_stage1, train_loader_stage2, train_loader_stage3,
         val_loader_stage1, val_loader_stage2, val_loader_stage3,
         test_loader, optimizer, model, lr_scheduler)

    Raises:
        ValueError: if fewer than three source domains remain after removing the test domain.
        NotImplementedError: for an unsupported `used_model`.

    Note: reads the notebook globals `device`, `lr`, `lr_decay_epoch` and `lr_decay_gamma`.
    """
    num_stages = 3  # the training routine consumes exactly three cumulative stages

    source_domains = [d for idx, d in enumerate(domains) if idx != test_domain_idx][:num_stages]
    if len(source_domains) < num_stages:
        raise ValueError(
            'setting() needs at least {} source domains besides the test domain, got {}'.format(
                num_stages, len(source_domains)))

    train_stages, val_stages = [], []
    train_cum = val_cum = None
    for domain in source_domains:
        train_part = ImageFolder(root=os.path.join('{}/train'.format(pacs_ver), domain), transform=train_tf)
        val_part = ImageFolder(root=os.path.join('{}/val'.format(pacs_ver), domain), transform=test_tf)
        # Dataset.__add__ builds a new ConcatDataset, so earlier stages are left untouched.
        train_cum = train_part if train_cum is None else train_cum + train_part
        val_cum = val_part if val_cum is None else val_cum + val_part
        train_stages.append(train_cum)
        val_stages.append(val_cum)

    test_set = ImageFolder(root=os.path.join('{}/test'.format(pacs_ver), domains[test_domain_idx]),
                           transform=test_tf)

    for stage, (train_set, val_set) in enumerate(zip(train_stages, val_stages), start=1):
        print('stage{} (train,val):'.format(stage), len(train_set), len(val_set))
    print('test :', len(test_set))

    def _make_loader(ds, shuffle):
        # One place for the DataLoader settings shared by every split.
        return DataLoader(ds, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers)

    train_loaders = [_make_loader(ds, True) for ds in train_stages]
    val_loaders = [_make_loader(ds, True) for ds in val_stages]
    test_loader = _make_loader(test_set, False)

    # Move to the configured device (CPU fallback) instead of a hard-coded .cuda().
    model = _build_model(used_model, is_pretrained, num_classes).to(device)

    optimizer = optim.SGD(model.parameters(), lr=lr, momentum=0.9, weight_decay=0.0005)
    lr_scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=lr_decay_epoch, gamma=lr_decay_gamma)

    return (train_loaders[0], train_loaders[1], train_loaders[2],
            val_loaders[0], val_loaders[1], val_loaders[2],
            test_loader, optimizer, model, lr_scheduler)


    

# Automation

In [None]:
save_model_setting(model_settings, used_model, domains, dataset, save_name)

# Repeat the full leave-one-domain-out sweep `number_of_tests` times.
# `try_check` is the 1-based repetition index passed to the test/plot/save helpers.
for try_check in range(1, number_of_tests + 1):

    for test_idx in range(len(domains)):

        ##########################
        #### Training Setting ####
        ##########################

        (tl_stage1, tl_stage2, tl_stage3,
         vl_stage1, vl_stage2, vl_stage3,
         test_loader, optimizer, model, lr_scheduler) = setting(
            test_idx, domains, batch_size, is_pretrained, train_tf, test_tf, used_model, pacs_ver
        )

        save_dir = save_route(test_idx, domains, dataset, save_name, used_model)

        try:
            # exist_ok avoids the race between an exists() check and makedirs().
            os.makedirs(save_dir, exist_ok=True)
        except OSError:
            # Best-effort, as before, but no longer swallows KeyboardInterrupt/SystemExit.
            print('Error : Creating directory. ' + save_dir)

        ##########################
        ####     Training     ####
        ##########################

        model, losses, accuracies = do_training(
            device, epochs, model, optimizer, criterion,
            tl_stage1, tl_stage2, tl_stage3,
            vl_stage1, vl_stage2, vl_stage3,
            lr_scheduler
        )
        do_test(device, model, criterion, test_loader, used_model, save_dir, try_check)

        plotting(losses, accuracies, used_model, save_dir, is_pretrained, try_check)
        save_model(model, used_model, save_dir, is_pretrained, try_check)
        
        

