In [1]:
%load_ext kedro.extras.extensions.ipython

2022-03-26 16:25:40,611 - kedro.framework.session.store - INFO - `read()` not implemented for `BaseSessionStore`. Assuming empty store.
2022-03-26 16:25:41,521 - root - INFO - ** Kedro project MEDGC Tesis
2022-03-26 16:25:41,522 - root - INFO - Defined global variable `context`, `session`, `catalog` and `pipelines`


In [2]:
import time
import sys
import argparse
import shutil
import os.path as osp

import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
from torch.optim import SGD
from torch.optim.lr_scheduler import LambdaLR
from torch.utils.data import DataLoader
import torch.nn.functional as F
import torchvision.transforms as T

sys.path.append('../../..')
from dalib.modules.domain_discriminator import DomainDiscriminator
from dalib.adaptation.dann import DomainAdversarialLoss, ImageClassifier
from common.utils.data import ForeverDataIterator
from common.utils.metric import accuracy, ConfusionMatrix
from common.utils.meter import AverageMeter, ProgressMeter
from common.utils.logger import CompleteLogger
from common.utils.analysis import collect_feature, tsne, a_distance

import common.vision.datasets as datasets
from common.vision.transforms import ResizeImage

sys.path.append('.')

import types
import ssl
# NOTE(review): this replaces the default HTTPS context factory with one that
# skips certificate verification, for the rest of the kernel session, for any
# stdlib HTTPS client that relies on the default context (e.g. urllib).
# Presumably added so the dataset downloads below succeed; those archives are
# then extracted, so a tampered download would go unnoticed. Prefer fixing the
# local certificate store and removing this override.
ssl._create_default_https_context = ssl._create_unverified_context



In [3]:
import os
from typing import Optional, Tuple, Any
from common.vision.datasets.imagelist import ImageList

class TDS(ImageList):
    """Target-domain digit dataset read from per-split image-list files.

    The list file for a split is ``os.path.join(root, image_list[split])``;
    parsing it is delegated to ``ImageList``. Each image is converted to the
    PIL mode given by ``mode`` (``"L"`` grayscale by default, or ``"RGB"``)
    before any transform runs.

    Args:
        root: Dataset root directory.
        mode: PIL conversion mode, ``"L"`` or ``"RGB"``.
        split: ``"train"``, ``"test"`` or ``"val"``.
        download: Accepted for signature compatibility with the other dataset
            classes; it is ignored (nothing is downloaded).
        **kwargs: Forwarded to ``ImageList`` (e.g. ``transform``).
    """
    # split name -> image-list file, relative to ``root``
    image_list = {
        "train": "image_list/tds_train.txt",
        "test": "image_list/tds_test.txt",
        "val": "image_list/tds_val.txt"
    }
    # digit class names, passed to ImageList as the class list
    CLASSES = ['0 - zero', '1 - one', '2 - two', '3 - three', '4 - four',
               '5 - five', '6 - six', '7 - seven', '8 - eight', '9 - nine']

    def __init__(self, root, mode="L", split='train', download: Optional[bool] = True, **kwargs):
        assert split in ['train', 'test', 'val'], "unknown split: {}".format(split)
        assert mode in ['L', 'RGB'], "unsupported image mode: {}".format(mode)
        data_list_file = os.path.join(root, self.image_list[split])

        self.mode = mode
        # self.CLASSES (not TDS.CLASSES) so a subclass can override the class list
        super().__init__(root, self.CLASSES, data_list_file=data_list_file, **kwargs)

    def __getitem__(self, index: int) -> Tuple[Any, int]:
        """Load one sample.

        Args:
            index (int): Index into ``self.samples``.

        Returns:
            tuple: ``(image, target)`` where ``image`` has been converted to
            ``self.mode`` and passed through ``self.transform`` (if set), and
            ``target`` is the class index, passed through
            ``self.target_transform`` when both are not ``None``.
        """
        path, target = self.samples[index]
        img = self.loader(path).convert(self.mode)
        if self.transform is not None:
            img = self.transform(img)
        if self.target_transform is not None and target is not None:
            target = self.target_transform(target)
        return img, target

    @classmethod
    def get_classes(cls):
        """Return the list of class names for this dataset class."""
        return cls.CLASSES

In [4]:
# Run on the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

In [5]:
# Display which device was selected (rendered as this cell's output).
device

device(type='cuda')

In [6]:
def train(train_source_iter: ForeverDataIterator, train_target_iter: ForeverDataIterator,
          model: ImageClassifier, domain_adv: DomainAdversarialLoss, optimizer: SGD,
          lr_scheduler: LambdaLR, epoch: int, args: argparse.Namespace,
          device: Optional[torch.device] = None):
    """Run one epoch (``args.iters_per_epoch`` steps) of DANN training.

    Each step does the following:
      * draws a labelled source batch and an (unlabelled) target batch;
      * runs one forward pass over their concatenation;
      * minimises ``cls_loss + args.trade_off * transfer_loss``, where
        ``cls_loss`` is the cross-entropy on the source logits and
        ``transfer_loss`` is ``domain_adv`` applied to the source/target
        features;
      * steps ``lr_scheduler`` once.

    Progress is printed every ``args.print_freq`` steps.

    Args:
        train_source_iter: Iterator yielding ``(images, labels)`` source batches.
        train_target_iter: Iterator yielding ``(images, _)`` target batches.
        model: Classifier returning ``(logits, features)`` in train mode.
        domain_adv: Domain-adversarial loss; its ``domain_discriminator_accuracy``
            is read after each call.
        optimizer: Optimizer over the model and discriminator parameters.
        lr_scheduler: Scheduler stepped once per iteration.
        epoch: Epoch index, used only for the progress prefix.
        args: Needs ``iters_per_epoch``, ``trade_off`` and ``print_freq``.
        device: Device to move batches to. Defaults to the device of the
            model's first parameter, so the function does not depend on a
            notebook-level ``device`` variable.
    """
    if device is None:
        device = next(model.parameters()).device

    batch_time = AverageMeter('Time', ':5.2f')
    data_time = AverageMeter('Data', ':5.2f')
    losses = AverageMeter('Loss', ':6.2f')
    cls_accs = AverageMeter('Cls Acc', ':3.1f')
    domain_accs = AverageMeter('Domain Acc', ':3.1f')
    progress = ProgressMeter(args.iters_per_epoch,
                             [batch_time, data_time, losses, cls_accs, domain_accs],
                             prefix="Epoch: [{}]".format(epoch))

    # switch to train mode
    model.train()
    domain_adv.train()

    end = time.time()
    for i in range(args.iters_per_epoch):
        x_s, labels_s = next(train_source_iter)
        x_t, _ = next(train_target_iter)

        x_s = x_s.to(device)
        x_t = x_t.to(device)
        labels_s = labels_s.to(device)

        # measure data loading time (includes the device copies above)
        data_time.update(time.time() - end)

        # One forward pass over both domains. Split the outputs back by the
        # actual batch sizes; chunk(2) would mis-assign rows if the source and
        # target batches ever differed in size.
        batch_sizes = [x_s.size(0), x_t.size(0)]
        x = torch.cat((x_s, x_t), dim=0)
        y, f = model(x)
        y_s, _ = y.split(batch_sizes, dim=0)
        f_s, f_t = f.split(batch_sizes, dim=0)

        cls_loss = F.cross_entropy(y_s, labels_s)
        transfer_loss = domain_adv(f_s, f_t)
        domain_acc = domain_adv.domain_discriminator_accuracy
        loss = cls_loss + transfer_loss * args.trade_off

        cls_acc = accuracy(y_s, labels_s)[0]

        losses.update(loss.item(), x_s.size(0))
        cls_accs.update(cls_acc.item(), x_s.size(0))
        domain_accs.update(domain_acc.item(), x_s.size(0))

        # compute gradient and do SGD step; the schedule is per iteration
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        lr_scheduler.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            progress.display(i)

In [7]:
def get_train_transform(random_horizontal_flip=True, random_color_jitter=False,
                        resize_size=224, norm_mean=(0.485, 0.456, 0.406), norm_std=(0.229, 0.224, 0.225)):
    """Build the training-time image pipeline.

    Steps, in order:
      1. ``ResizeImage(resize_size)`` (no cropping is performed);
      2. ``RandomHorizontalFlip`` when ``random_horizontal_flip`` is true;
      3. ``ColorJitter`` (brightness/contrast/saturation/hue all 0.5) when
         ``random_color_jitter`` is true;
      4. ``ToTensor`` followed by ``Normalize(norm_mean, norm_std)``.

    The default mean/std are the common ImageNet RGB statistics; pass values
    that match the image channel count (e.g. scalars for grayscale).

    Returns:
        A ``torchvision.transforms.Compose`` with the steps above.
    """
    resize = ResizeImage(resize_size)
    optional_augmentations = [
        T.RandomHorizontalFlip() if random_horizontal_flip else None,
        T.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5, hue=0.5) if random_color_jitter else None,
    ]
    tensor_steps = [T.ToTensor(), T.Normalize(mean=norm_mean, std=norm_std)]

    enabled = [step for step in optional_augmentations if step is not None]
    return T.Compose([resize, *enabled, *tensor_steps])

def get_val_transform(resize_size=224, norm_mean=(0.485, 0.456, 0.406), norm_std=(0.229, 0.224, 0.225)):
    """Build the deterministic evaluation pipeline.

    Steps: ``ResizeImage(resize_size)`` -> ``ToTensor`` ->
    ``Normalize(norm_mean, norm_std)``. There is no cropping and no random
    augmentation, so the same image always yields the same tensor.

    Returns:
        A ``torchvision.transforms.Compose`` with the three steps above.
    """
    resize = ResizeImage(resize_size)
    to_tensor = T.ToTensor()
    normalize = T.Normalize(mean=norm_mean, std=norm_std)
    return T.Compose([resize, to_tensor, normalize])

def get_dataset(dataset_name, root, train_source_transform, val_transform, train_target_transform=None):
    """Build the MNIST (source) -> TDS (target) domain-adaptation datasets.

    Args:
        dataset_name: Currently unused; the source/target pair is fixed to
            MNIST -> TDS.
        root: Directory containing the ``MNIST`` and ``TDS`` sub-folders.
        train_source_transform: Transform for the MNIST training images.
        val_transform: Transform for the TDS ``val`` and ``test`` splits.
        train_target_transform: Transform for the TDS ``train`` split.
            Defaults to ``train_source_transform`` so both domains get the same
            training-time preprocessing.

    Returns:
        ``(train_source_dataset, train_target_dataset, val_dataset,
        test_dataset, num_classes, class_names)``; the class names come from
        MNIST.
    """
    if train_target_transform is None:
        train_target_transform = train_source_transform

    tds_root = osp.join(root, 'TDS')
    train_source_dataset = datasets.MNIST(osp.join(root, 'MNIST'), download=True, transform=train_source_transform)
    # The target *training* split uses the training transform; val_transform is
    # reserved for the evaluation splits.
    train_target_dataset = TDS(tds_root, split='train', download=True, transform=train_target_transform)
    val_dataset = TDS(tds_root, split='val', download=True, transform=val_transform)
    test_dataset = TDS(tds_root, split='test', download=True, transform=val_transform)
    class_names = datasets.MNIST.get_classes()
    num_classes = len(class_names)

    return train_source_dataset, train_target_dataset, val_dataset, test_dataset, num_classes, class_names

def get_model(num_classes=10, in_channels=1):
    """Build the LeNet-style feature extractor used as the DANN backbone.

    Two 5x5 convolutions, each followed by 2x2 max-pooling, turn a
    ``in_channels x 28 x 28`` image into ``50 x 4 x 4`` feature maps. These are
    flattened and projected to a 500-d feature vector (with ReLU and dropout).

    The returned module exposes the following:
      * ``out_features`` (500);
      * ``num_classes``;
      * ``copy_head()``, which builds a fresh ``500 -> num_classes`` linear
        layer.

    Args:
        num_classes: Output size of ``copy_head()``. The default of 10 matches
            the previously hard-coded value.
        in_channels: Number of input image channels. The default of 1 is for
            grayscale digits.

    Returns:
        A ``LeNet`` instance (an ``nn.Sequential``).
    """

    class LeNet(nn.Sequential):
        """Conv/pool stack followed by a 500-d fully connected feature layer."""

        def __init__(self, num_classes=10, in_channels=1):
            super().__init__(
                nn.Conv2d(in_channels, 20, kernel_size=5),
                nn.MaxPool2d(2),
                nn.ReLU(),
                nn.Conv2d(20, 50, kernel_size=5),
                nn.Dropout2d(p=0.5),
                nn.MaxPool2d(2),
                nn.ReLU(),
                nn.Flatten(start_dim=1),
                # 50 channels x 4 x 4 positions for 28x28 inputs
                nn.Linear(50 * 4 * 4, 500),
                nn.ReLU(),
                nn.Dropout(p=0.5),
            )
            self.num_classes = num_classes
            # width of the feature vector produced by forward()
            self.out_features = 500

        def copy_head(self):
            """Return a freshly initialised ``out_features -> num_classes`` linear head."""
            return nn.Linear(self.out_features, self.num_classes)

    return LeNet(num_classes=num_classes, in_channels=in_channels)

def validate(val_loader, model, args, device) -> float:
    """Evaluate ``model`` on ``val_loader`` and return its top-1 accuracy.

    The model is put in eval mode and run under ``torch.no_grad()``. Running
    averages of cross-entropy loss and top-1 accuracy are printed every
    ``args.print_freq`` batches. When ``args.per_class_eval`` is true, a
    confusion matrix over ``args.class_names`` is accumulated and printed at
    the end.

    The model is left in eval mode; callers that keep training must switch it
    back with ``model.train()``.

    Args:
        val_loader: Loader yielding ``(images, target)`` batches.
        model: Classifier returning logits in eval mode.
        args: Needs ``print_freq``, ``per_class_eval`` and (if per-class
            evaluation is on) ``class_names``.
        device: Device the batches are moved to.

    Returns:
        Sample-weighted average top-1 accuracy as computed by ``accuracy``.
        The values logged in this notebook look like percentages.
    """
    batch_time = AverageMeter('Time', ':6.3f')
    losses = AverageMeter('Loss', ':.4e')
    top1 = AverageMeter('Acc@1', ':6.2f')
    progress = ProgressMeter(
        len(val_loader),
        [batch_time, losses, top1],
        prefix='Test: ')

    # switch to evaluate mode (e.g. disables dropout)
    model.eval()
    # Compare against None explicitly below instead of relying on the
    # truthiness of the ConfusionMatrix object.
    confmat = ConfusionMatrix(len(args.class_names)) if args.per_class_eval else None

    with torch.no_grad():
        end = time.time()
        for i, (images, target) in enumerate(val_loader):
            images = images.to(device)
            target = target.to(device)

            # compute output
            output = model(images)
            loss = F.cross_entropy(output, target)

            # measure accuracy and record loss
            acc1, = accuracy(output, target, topk=(1,))
            if confmat is not None:
                confmat.update(target, output.argmax(1))
            losses.update(loss.item(), images.size(0))
            top1.update(acc1.item(), images.size(0))

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if i % args.print_freq == 0:
                progress.display(i)

        print(' * Acc@1 {top1.avg:.3f}'.format(top1=top1))
        if confmat is not None:
            print(confmat.format(args.class_names))

    return top1.avg

In [17]:
# Experiment configuration, standing in for the argparse options of a script.
args = types.SimpleNamespace(
    # logging / run mode: 'train' skips checkpoint loading in the resume cell;
    # 'analysis' runs t-SNE + A-distance; 'test' evaluates on the test split
    log='dann',
    phase='analysis',
    # preprocessing
    no_hflip=True,
    resize_size=28,
    norm_mean=0.5,
    norm_std=0.5,
    # data (class_names is filled in from get_dataset's return value)
    class_names=None,
    data="as",
    root="",
    batch_size=32,
    workers=0,
    # model
    no_pool=True,
    bottleneck_dim=256,
    # optimisation and lr schedule
    lr=0.01,
    lr_gamma=0.001,
    lr_decay=0.75,
    momentum=0.9,
    weight_decay=1e-3,
    epochs=5,
    iters_per_epoch=2500,
    trade_off=1.,
    # reporting
    print_freq=100,
    per_class_eval=True,
)

In [9]:
# Let cuDNN autotune convolution algorithms; worthwhile here because every
# input is resized to the same fixed size.
cudnn.benchmark = True
# Run logger for this phase; it also provides the checkpoint paths and the
# visualisation directory used by later cells.
logger = CompleteLogger(args.log, args.phase)

In [10]:
# Data loading code
# Training augmentation honours args.no_hflip: random horizontal flips are
# disabled when it is True (flipped digits are not valid digits).
train_transform = get_train_transform(random_horizontal_flip=not args.no_hflip,
                                      resize_size=args.resize_size,
                                      norm_mean=args.norm_mean, norm_std=args.norm_std)
val_transform = get_val_transform(resize_size=args.resize_size, norm_mean=args.norm_mean, norm_std=args.norm_std)

In [11]:
# Build the datasets; the class names returned by get_dataset are stored on args.
(train_source_dataset, train_target_dataset, val_dataset, test_dataset,
 num_classes, args.class_names) = get_dataset(args.data, args.root, train_transform, val_transform)

# Training loaders: shuffled, and any incomplete final batch is dropped.
train_source_loader = DataLoader(
    train_source_dataset,
    batch_size=args.batch_size,
    shuffle=True,
    num_workers=args.workers,
    drop_last=True,
)
train_target_loader = DataLoader(
    train_target_dataset,
    batch_size=args.batch_size,
    shuffle=True,
    num_workers=args.workers,
    drop_last=True,
)

# Evaluation loaders keep dataset order and every sample.
val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers)
test_loader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers)

# Wrap the training loaders so train() can call next() for
# args.iters_per_epoch steps regardless of loader length.
train_source_iter = ForeverDataIterator(train_source_loader)
train_target_iter = ForeverDataIterator(train_target_loader)

Downloading image_list
Downloading https://cloud.tsinghua.edu.cn/seafhttp/files/983f17fd-82d5-4919-8817-317c4f14547f/image_list.zip to MNIST\image_list.zip
233472it [00:02, 87705.53it/s]                             
Extracting MNIST\image_list.zip to MNIST
Downloading mnist_train_image
Downloading https://cloud.tsinghua.edu.cn/seafhttp/files/981517c8-807e-420a-9891-030fab5c1ab6/mnist_image.tar.gz to MNIST\mnist_image.tar.gz
34944000it [05:35, 104145.41it/s]                              
Extracting MNIST\mnist_image.tar.gz to MNIST
lr: 0.001
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
Epoch: [0][   0/2500]	Time  4.32 ( 4.32)	Data  0.83 ( 0.83)	Loss   3.14 (  3.14)	Cls Acc 6.2 (6.2)	Domain Acc 42.2 (42.2)
Epoch: [0][ 100/2500]	Time  0.04 ( 0.09)	Data  0.03 ( 0.05)	Loss   1.32 (  2.17)	Cls Acc 71.9 (42.0)	Domain Acc 79.7 (72.8)
Epoch: [0][ 200/2500]	Time  0.05 ( 0.07)	Data  0.04 ( 0.04)	Loss   1.22 (  1.78)	Cls 

In [12]:
# The LeNet backbone already ends in a flat 500-d feature vector, so the
# pooling stage is an identity.
backbone = get_model()
pool_layer = nn.Identity()

# Classifier wrapping the backbone, pool layer and a bottleneck of size
# args.bottleneck_dim.
classifier = ImageClassifier(
    backbone,
    num_classes,
    bottleneck_dim=args.bottleneck_dim,
    pool_layer=pool_layer,
    finetune=True,
)
classifier = classifier.to(device)

# Domain discriminator over the classifier's feature dimension.
domain_discri = DomainDiscriminator(in_feature=classifier.features_dim, hidden_size=1024)
domain_discri = domain_discri.to(device)

In [13]:
# Nesterov SGD over the classifier and discriminator parameter groups.
optimizer = SGD(
    classifier.get_parameters() + domain_discri.get_parameters(),
    lr=args.lr,
    momentum=args.momentum,
    weight_decay=args.weight_decay,
    nesterov=True,
)

# Per-iteration decay: LambdaLR multiplies each parameter group's initial lr
# by  args.lr * (1 + lr_gamma * step) ** (-lr_decay).
# Presumably get_parameters() sets per-group base lrs (e.g. a smaller one for
# the backbone); verify against its implementation.
lr_scheduler = LambdaLR(
    optimizer,
    lambda step: args.lr * (1. + args.lr_gamma * float(step)) ** (-args.lr_decay),
)

# Adversarial loss built around the domain discriminator.
domain_adv = DomainAdversarialLoss(domain_discri).to(device)

In [18]:
# Every phase except 'train' starts from the best saved checkpoint.
# NOTE(review): on a fresh Restart & Run All this cell runs before the training
# loop, so a 'best' checkpoint from an earlier run must already exist.
if args.phase != 'train':
    checkpoint = torch.load(logger.get_checkpoint_path('best'), map_location='cpu')
    classifier.load_state_dict(checkpoint)

    if args.phase == 'analysis':
        # Embed both domains with backbone -> pool -> bottleneck (no head).
        feature_extractor = nn.Sequential(classifier.backbone, classifier.pool_layer, classifier.bottleneck).to(device)
        source_feature = collect_feature(train_source_loader, feature_extractor, device)
        target_feature = collect_feature(train_target_loader, feature_extractor, device)

        # t-SNE plot of source vs. target features.
        tSNE_filename = osp.join(logger.visualize_directory, 'TSNE.pdf')
        tsne.visualize(source_feature, target_feature, tSNE_filename)
        print("Saving t-SNE to", tSNE_filename)

        # A-distance: a scalar measure of the discrepancy between the two
        # feature distributions.
        A_distance = a_distance.calculate(source_feature, target_feature, device)
        print("A-distance =", A_distance)
    elif args.phase == 'test':
        acc1 = validate(test_loader, classifier, args, device)
        print(acc1)

In [15]:
# start training
best_acc1 = 0.
for epoch in range(args.epochs):
    # lr of the first parameter group
    print("lr:", lr_scheduler.get_last_lr()[0])

    # train for one epoch
    train(train_source_iter, train_target_iter, classifier, domain_adv, optimizer,
          lr_scheduler, epoch, args)

    # evaluate on validation set
    acc1 = validate(val_loader, classifier, args, device)

    # Always save the latest weights and promote them to 'best' on improvement.
    # The first epoch is always promoted so a 'best' checkpoint exists even if
    # validation accuracy never rises above 0; later cells load 'best'.
    torch.save(classifier.state_dict(), logger.get_checkpoint_path('latest'))
    if epoch == 0 or acc1 > best_acc1:
        shutil.copy(logger.get_checkpoint_path('latest'), logger.get_checkpoint_path('best'))
    best_acc1 = max(acc1, best_acc1)

In [16]:
print("best_acc1 = {:3.1f}".format(best_acc1))

# Evaluate on the test set with the best-validation weights. The checkpoint is
# loaded onto the CPU first, as in the resume cell, so a GPU-saved checkpoint
# also loads on a CPU-only machine; load_state_dict copies the values into the
# model's existing parameters.
classifier.load_state_dict(torch.load(logger.get_checkpoint_path('best'), map_location='cpu'))
acc1 = validate(test_loader, classifier, args, device)
print("test_acc1 = {:3.1f}".format(acc1))