<a href="https://colab.research.google.com/github/nikhilreddybilla28/acadacs/blob/main/ToAlign.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!git clone https://github.com/microsoft/UDA.git

Cloning into 'UDA'...
remote: Enumerating objects: 369, done.[K
remote: Counting objects: 100% (369/369), done.[K
remote: Compressing objects: 100% (263/263), done.[K
remote: Total 369 (delta 127), reused 304 (delta 88), pack-reused 0[K
Receiving objects: 100% (369/369), 4.81 MiB | 8.83 MiB/s, done.
Resolving deltas: 100% (127/127), done.


In [None]:
import sys
# Extend the module search path so that project packages can be imported by later cells.
# NOTE(review): these entries point at /content/SCDA, while the only repository cloned
# by this notebook is /content/UDA -- confirm an SCDA checkout is expected to exist,
# or whether these paths are leftovers from another notebook.
sys.path.append("/content/SCDA/src")
sys.path.append("/content/SCDA/model/")
sys.path.append("/content/SCDA/src/UDA")

In [None]:
%cd UDA
!ls

/content/UDA
assets		    dataset_map  main.py    requirements.txt  trainer
CODE_OF_CONDUCT.md  datasets	 models     SECURITY.md       utils
configs		    LICENSE	 README.md  SUPPORT.md


In [None]:
! pip install timm

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting timm
  Downloading timm-0.6.5-py3-none-any.whl (512 kB)
[K     |████████████████████████████████| 512 kB 9.5 MB/s 
Installing collected packages: timm
Successfully installed timm-0.6.5


In [None]:
!unzip '/content/drive/MyDrive/OfficeHomeDataset_10072016.zip' -d '/content/data'

In [None]:

import os
import random

from PIL import Image
from PIL import ImageFile

from torch.utils.data import Dataset
from datasets.transforms import transform_train, transform_test


ImageFile.LOAD_TRUNCATED_IMAGES = True


class CommonDataset(Dataset):
    """Base dataset that yields ``{'image', 'label', 'domain'}`` dictionaries.

    Subclasses are expected to populate:

    * ``data``       -- sequence of ``(relative_path, label)`` pairs,
    * ``domain_id``  -- one domain index per entry of ``data``,
    * ``image_root`` -- directory the relative paths are joined onto,
    * ``_domains``   -- value exposed through the ``domains`` property.

    Args:
        is_train: when true the transform is ``transform_train()``,
            otherwise ``transform_test()``.
    """

    def __init__(self, is_train: bool = True):
        self.data = []
        self.domain_id = []
        self.image_root = ''
        self.transform = transform_train() if is_train else transform_test()
        self._domains = None
        self.num_domain = 1

    @property
    def domains(self):
        """Return the domain names registered by the subclass (``None`` by default)."""
        return self._domains

    def __getitem__(self, index):
        """Load sample ``index`` as an RGB image and apply the configured transform.

        Returns:
            dict with keys ``'image'``, ``'label'`` and ``'domain'``.
        """
        domain = self.domain_id[index]
        path, label = self.data[index]
        path = os.path.join(self.image_root, path)
        # convert() returns a new, fully loaded image, so the file handle can be
        # released as soon as the ``with`` block exits.
        with Image.open(path) as image:
            image = image.convert('RGB')
        if self.transform is not None:
            image = self.transform(image)

        return {
            'image': image,
            'label': label,
            'domain': domain
        }

    def __len__(self):
        """Number of samples.

        The previous body was ``pass``, which made ``len(dataset)`` raise
        ``TypeError`` for any subclass that did not override it.
        """
        return len(self.data)

In [None]:

import os

from datasets.common_dataset import CommonDataset
from datasets.reader import read_images_labels


class OfficeHome(CommonDataset):
    """Single-domain Office-Home dataset backed by a ``dataset_map`` listing file.

    Expected directory layout below ``data_root``::

        -data_root:
         |
         |-art
         |-clipart
         |-product
         |-real_world
           |-Alarm_Clock
             |-0001.jpg

    Args:
        data_root: directory that image paths from the listing are joined onto.
        domains: list of domain names; only ``domains[0]`` is validated and read.
        status: one of ``'train'``, ``'val'``, ``'test'``. ``'train'`` selects the
            training transform and asks the reader to shuffle.
        trim: forwarded to ``read_images_labels``; its exact meaning is defined
            by that helper. It was previously accepted but silently replaced by 0.

    Raises:
        ValueError: if ``domains[0]`` or ``status`` is not a supported value.
    """
    def __init__(self, data_root, domains: list, status: str = 'train', trim: int = 0):
        super().__init__(is_train=(status == 'train'))

        self._domains = ['product', 'art', 'clipart', 'real_world']

        if domains[0] not in self._domains:
            raise ValueError(f'Expected \'domain\' in {self._domains}, but got {domains[0]}')
        _status = ['train', 'val', 'test']
        if status not in _status:
            raise ValueError(f'Expected \'status\' in {_status}, but got {status}')

        self.image_root = data_root

        # The listing path is relative to the current working directory.
        data = read_images_labels(
            os.path.join('dataset_map/office_home', f'{domains[0]}.txt'),
            shuffle=(status == 'train'),
            trim=trim
        )

        self.data = data
        # Only one domain is loaded, so every sample gets domain index 0.
        self.domain_id = [0] * len(self.data)

    def __len__(self):
        """Number of samples read from the listing file."""
        return len(self.data)

# MODEL

In [None]:

import torch.nn as nn

from utils.torch_funcs import init_weights_fc, init_weights_fc0, init_weights_fc1, init_weights_fc2

__all__ = ['BaseModel']


class BaseModel(nn.Module):
    """Backbone-agnostic classifier base with optional HDA heads and ToAlign re-weighting.

    Subclasses set ``self._fdim`` (feature width), override ``forward_backbone`` and
    ``get_backbone_parameters``, and call ``build_head()`` once ``_fdim`` is known.

    Args:
        num_classes: output width of every linear head.
        hda: when true, three extra linear heads (``fc0``..``fc2``) are created and
            their summed output is subtracted from the main logits.
        toalign: stored on the instance; ``forward`` is driven by its own
            ``toalign`` argument rather than by this attribute.
        **kwargs: accepted and ignored.
    """

    def __init__(self,
                 num_classes: int = 1000,
                 hda: bool = False,  # whether use hda head
                 toalign: bool = False,  # whether use toalign
                 **kwargs
                 ):
        super().__init__()
        self.num_classes = num_classes
        self._fdim = None

        # HDA
        self.hda = hda

        # toalign
        self.toalign = toalign

    def build_head(self):
        """Create the classification head and, if ``hda`` is set, the three HDA heads."""
        # classification head
        self.fc = nn.Linear(self.fdim, self.num_classes)
        nn.init.kaiming_normal_(self.fc.weight)
        if self.fc.bias is not None:
            nn.init.zeros_(self.fc.bias)

        # HDA head (re-initialises ``fc`` with the project's HDA-specific initialiser)
        if self.hda:
            self.fc.apply(init_weights_fc)
            self.fc0 = nn.Linear(self.fdim, self.num_classes)
            self.fc0.apply(init_weights_fc0)
            self.fc1 = nn.Linear(self.fdim, self.num_classes)
            self.fc1.apply(init_weights_fc1)
            self.fc2 = nn.Linear(self.fdim, self.num_classes)
            self.fc2.apply(init_weights_fc2)

    @property
    def fdim(self) -> int:
        """Width of the feature vector produced by ``forward_backbone``."""
        return self._fdim

    def get_backbone_parameters(self):
        """Return optimizer parameter groups for the backbone (none in the base class)."""
        return []

    def get_parameters(self):
        """Return optimizer parameter groups: backbone groups plus heads with ``lr_mult`` 10."""
        parameter_list = self.get_backbone_parameters()
        parameter_list.append({'params': self.fc.parameters(), 'lr_mult': 10})
        if self.hda:
            parameter_list.append({'params': self.fc0.parameters(), 'lr_mult': 10})
            parameter_list.append({'params': self.fc1.parameters(), 'lr_mult': 10})
            parameter_list.append({'params': self.fc2.parameters(), 'lr_mult': 10})

        return parameter_list

    def forward_backbone(self, x):
        """ input x --> output feature """
        return x

    def _get_toalign_weight(self, f, labels=None, eps: float = 1e-12):
        """Compute per-sample channel weights from the classifier row of each label.

        The weights are the (detached) classifier rows selected by ``labels``,
        rescaled so that ``f * w_pos`` has the same squared L2 norm as ``f``.

        Args:
            f: features of shape [B, fdim].
            labels: integer class index per sample; required.
            eps: lower bound applied to the re-weighted energy before dividing.
                Without it, an all-zero feature row produces 0/0 and the
                returned weights become NaN.

        Returns:
            Tensor of shape [B, fdim].
        """
        assert labels is not None, f'labels should be asigned'
        w = self.fc.weight[labels].detach()  # [B, fdim]
        if self.hda:
            w0 = self.fc0.weight[labels].detach()
            w1 = self.fc1.weight[labels].detach()
            w2 = self.fc2.weight[labels].detach()
            w = w - (w0 + w1 + w2)
        eng_org = (f**2).sum(dim=1, keepdim=True)  # [B, 1]
        eng_aft = ((f*w)**2).sum(dim=1, keepdim=True)  # [B, 1]
        # Clamp the denominator so degenerate rows give finite weights.
        scalar = (eng_org / eng_aft.clamp_min(eps)).sqrt()
        w_pos = w * scalar

        return w_pos

    def forward(self, x, toalign=False, labels=None) -> tuple:
        """Run the backbone and heads.

        Args:
            x: backbone input.
            toalign: when true, features are re-weighted with
                ``_get_toalign_weight`` before being fed to the heads.
            labels: class indices; required when ``toalign`` is true.

        Returns:
            ``(f, y)`` without HDA, or ``(f, y - z, z)`` with HDA, where ``z`` is
            the sum of the three HDA head outputs. With ``toalign`` the returned
            feature is the re-weighted one.
        """
        f = self.forward_backbone(x)  # output feature [B, C]
        assert f.dim() == 2, f'Expected dim of returned features to be 2, but found {f.dim()}'

        if toalign:
            w_pos = self._get_toalign_weight(f, labels=labels)
            f_pos = f * w_pos
            y_pos = self.fc(f_pos)
            if self.hda:
                z_pos0 = self.fc0(f_pos)
                z_pos1 = self.fc1(f_pos)
                z_pos2 = self.fc2(f_pos)
                z_pos = z_pos0 + z_pos1 + z_pos2
                return f_pos, y_pos - z_pos, z_pos
            else:
                return f_pos, y_pos
        else:
            y = self.fc(f)
            if self.hda:
                z0 = self.fc0(f)
                z1 = self.fc1(f)
                z2 = self.fc2(f)
                z = z0 + z1 + z2
                return f, y - z, z
            else:
                return f, y

In [None]:
import torch.nn as nn

from utils.torch_funcs import grl_hook

__all__ = ['Discriminator']


class Discriminator(nn.Module):
    """Three-layer MLP domain discriminator fed through a gradient hook.

    Args:
        in_feature: width of the input vectors.
        hidden_size: width of both hidden layers.
        out_feature: width of the output (default 1).
    """

    def __init__(self, in_feature: int, hidden_size: int, out_feature: int = 1):
        super(Discriminator, self).__init__()
        self.ad_layer1 = nn.Linear(in_feature, hidden_size)
        self.ad_layer2 = nn.Linear(hidden_size, hidden_size)
        self.ad_layer3 = nn.Linear(hidden_size, out_feature)
        self.relu1 = nn.ReLU()
        self.relu2 = nn.ReLU()
        self.dropout1 = nn.Dropout(0.5)
        self.dropout2 = nn.Dropout(0.5)
        # Not used in forward(); kept so the attribute set stays unchanged.
        self.dropout3 = nn.Dropout(0.5)

        self._init_params()

    def forward(self, x, coeff: float):
        """Score ``x`` and attach ``grl_hook(coeff)`` to its gradient.

        Args:
            x: input of shape [B, in_feature].
            coeff: coefficient passed to ``grl_hook``.

        Returns:
            Tensor of shape [B, out_feature] (raw outputs, no activation).
        """
        x = x * 1.  # to avoid affect the grad from another pipeline to x_0
        # register_hook raises on tensors that do not require grad (for example under
        # torch.no_grad()), so the hook is attached only when a gradient will flow.
        if x.requires_grad:
            x.register_hook(grl_hook(coeff))
        # Call the modules directly rather than ``.forward`` so module hooks run.
        x = self.ad_layer1(x)
        x = self.relu1(x)
        x = self.dropout1(x)
        x = self.ad_layer2(x)
        x = self.relu2(x)
        x = self.dropout2(x)
        y = self.ad_layer3(x)

        return y

    def _init_params(self):
        """Initialise conv, batch-norm and linear sub-modules in registration order."""
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_uniform_(m.weight)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, (nn.BatchNorm2d, nn.BatchNorm1d)):
                nn.init.normal_(m.weight, 1., 0.02)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.kaiming_normal_(m.weight)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    def get_parameters(self):
        """Return a single optimizer parameter group with ``lr_mult`` 10."""
        return [{'params': self.parameters(), 'lr_mult': 10}]

In [None]:

from typing import Any, Type, Union, List, Optional

import torch
import torch.nn as nn

from models.base_model import BaseModel
from utils.utils import init_from_pretrained_weights

__all__ = ['resnet18', 'resnet34', 'resnet50', 'resnet101']

model_urls = {
    'resnet18': 'https://download.pytorch.org/models/resnet18-f37072fd.pth',
    'resnet34': 'https://download.pytorch.org/models/resnet34-b627a593.pth',
    'resnet50': 'https://download.pytorch.org/models/resnet50-0676ba61.pth',
    'resnet101': 'https://download.pytorch.org/models/resnet101-63fe2227.pth',
}


def conv3x3(in_planes: int, out_planes: int, stride: int = 1, groups: int = 1, dilation: int = 1):
    """Build a bias-free 3x3 convolution whose padding equals its dilation."""
    conv_kwargs = dict(
        kernel_size=3,
        stride=stride,
        padding=dilation,
        groups=groups,
        bias=False,
        dilation=dilation,
    )
    return nn.Conv2d(in_planes, out_planes, **conv_kwargs)


def conv1x1(in_planes: int, out_planes: int, stride: int = 1):
    """Build a bias-free 1x1 (pointwise) convolution."""
    pointwise = nn.Conv2d(
        in_planes,
        out_planes,
        kernel_size=1,
        stride=stride,
        bias=False,
    )
    return pointwise


class BasicBlock(nn.Module):
    """Residual block made of two 3x3 convolutions and an optional projection shortcut."""

    expansion: int = 1

    def __init__(self, inplanes: int, planes: int, stride: int = 1, downsample: Optional[nn.Module] = None):
        super().__init__()
        # When stride != 1, conv1 and the downsample branch both reduce the resolution.
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        """Return ``relu(main_path(x) + shortcut(x))``."""
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))

        # Identity shortcut unless a projection module was supplied.
        shortcut = x if self.downsample is None else self.downsample(x)

        out += shortcut
        return self.relu(out)


class Bottleneck(nn.Module):
    """Residual bottleneck block: 1x1 reduce, 3x3, 1x1 expand (x4), plus shortcut."""

    expansion: int = 4

    def __init__(self, inplanes: int, planes: int, stride: int = 1, downsample: Optional[nn.Module] = None):
        super().__init__()
        expanded = planes * self.expansion
        self.conv1 = conv1x1(inplanes, planes)
        self.bn1 = nn.BatchNorm2d(planes)
        # When stride != 1, conv2 and the downsample branch both reduce the resolution.
        self.conv2 = conv3x3(planes, planes, stride)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = conv1x1(planes, expanded)
        self.bn3 = nn.BatchNorm2d(expanded)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        """Return ``relu(main_path(x) + shortcut(x))``."""
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))

        # Identity shortcut unless a projection module was supplied.
        shortcut = x if self.downsample is None else self.downsample(x)

        out += shortcut
        return self.relu(out)


class ResNet(BaseModel):
    """ResNet backbone that plugs into ``BaseModel``'s head construction.

    Args:
        block: residual block class (``BasicBlock`` or ``Bottleneck``).
        layers: number of blocks in each of the four stages.
        num_classes: forwarded to ``BaseModel``.
        **kwargs: forwarded to ``BaseModel``.
    """

    def __init__(self,
                 block: Type[Union[BasicBlock, Bottleneck]],
                 layers: List[int],
                 num_classes: int = 1000,
                 **kwargs
                 ):
        super().__init__(num_classes=num_classes, **kwargs)

        # Running input-channel count; read and advanced by _make_layer.
        self.inplanes = 64

        # Stem
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        # Four residual stages; every stage after the first halves the resolution.
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
        self.global_avgpool = nn.AdaptiveAvgPool2d((1, 1))

        # Modules whose parameters form the backbone optimizer group.
        self.feature_layers = [self.conv1, self.bn1, self.layer1, self.layer2, self.layer3, self.layer4]

        self._fdim = 512 * block.expansion

        self._init_params()

        # Heads are built last, after the backbone has been initialised.
        self.build_head()

    def _make_layer(self, block: Type[Union[BasicBlock, Bottleneck]], planes: int, blocks: int, stride: int = 1):
        """Stack ``blocks`` residual blocks; only the first may stride or project."""
        out_planes = planes * block.expansion
        needs_projection = stride != 1 or self.inplanes != out_planes
        downsample = nn.Sequential(
            conv1x1(self.inplanes, out_planes, stride=stride),
            nn.BatchNorm2d(out_planes),
        ) if needs_projection else None

        stage = [block(self.inplanes, planes, stride, downsample)]
        self.inplanes = out_planes
        stage.extend(block(self.inplanes, planes) for _ in range(1, blocks))

        return nn.Sequential(*stage)

    def _init_params(self):
        """Initialise conv, batch-norm and linear sub-modules in registration order."""
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                nn.init.kaiming_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)
            elif isinstance(module, (nn.BatchNorm2d, nn.BatchNorm1d)):
                nn.init.normal_(module.weight, 1., 0.02)
                nn.init.constant_(module.bias, 0)
            elif isinstance(module, nn.Linear):
                nn.init.kaiming_normal_(module.weight)
                if module.bias is not None:
                    nn.init.zeros_(module.bias)

    def get_backbone_parameters(self):
        """Return one optimizer group (``lr_mult`` 1) holding all ``feature_layers`` parameters."""
        backbone_params = [p for layer in self.feature_layers for p in layer.parameters()]
        return [{'params': backbone_params, 'lr_mult': 1}]

    def _forward_impl(self, x):
        """Run the stem and the four residual stages; returns a spatial feature map."""
        x = self.maxpool(self.relu(self.bn1(self.conv1(x))))
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            x = stage(x)
        return x

    def forward_backbone(self, x):
        """Global-average-pool the feature map and flatten it to [B, fdim]."""
        pooled = self.global_avgpool(self._forward_impl(x))
        return torch.flatten(pooled, 1)


def _resnet(arch: str, block, layers: List[int], pretrained: bool, num_classes: int, **kwargs: Any) -> ResNet:
    """Shared builder behind the public ``resnetXX`` factories.

    Args:
        arch: key into ``model_urls`` used to fetch the pretrained checkpoint.
        block: residual block class passed to ``ResNet``.
        layers: number of blocks per stage.
        pretrained: when true, download the checkpoint (cached under ``downloads``)
            and hand it to ``init_from_pretrained_weights``.
        num_classes: forwarded to ``ResNet``.
        **kwargs: forwarded to ``ResNet``.
    """
    model = ResNet(block=block, layers=layers, num_classes=num_classes, **kwargs)
    if pretrained:
        state_dict = torch.hub.load_state_dict_from_url(
            url=model_urls[arch], map_location='cpu', model_dir='downloads'
        )
        init_from_pretrained_weights(model, state_dict)

    return model


def resnet18(pretrained: bool = True, num_classes: int = 1000, **kwargs: Any) -> ResNet:
    """ResNet-18: ``BasicBlock`` with stage depths [2, 2, 2, 2]."""
    return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, num_classes, **kwargs)


def resnet34(pretrained: bool = True, num_classes: int = 1000, **kwargs: Any) -> ResNet:
    """ResNet-34: ``BasicBlock`` with stage depths [3, 4, 6, 3]."""
    return _resnet('resnet34', BasicBlock, [3, 4, 6, 3], pretrained, num_classes, **kwargs)


def resnet50(pretrained: bool = True, num_classes: int = 1000, **kwargs: Any) -> ResNet:
    """ResNet-50: ``Bottleneck`` with stage depths [3, 4, 6, 3]."""
    return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, num_classes, **kwargs)


def resnet101(pretrained: bool = True, num_classes: int = 1000, **kwargs: Any) -> ResNet:
    """ResNet-101: ``Bottleneck`` with stage depths [3, 4, 23, 3]."""
    return _resnet('resnet101', Bottleneck, [3, 4, 23, 3], pretrained, num_classes, **kwargs)


# TRAINER

In [None]:

import os
import logging

import torch
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from timm.utils import accuracy, AverageMeter

from utils.utils import save_model, write_log
from utils.lr_scheduler import inv_lr_scheduler
from datasets import *
from models import *


class BaseTrainer:
    """Source-only baseline trainer.

    Construction builds the data loaders, the base network and the optimizer,
    then tries to resume from ``models-last.pt`` in ``cfg.TRAIN.OUTPUT_CKPT``.
    ``one_step`` trains with cross-entropy on source batches only; subclasses
    override ``build_models`` / ``one_step`` to add further losses.
    """
    def __init__(self, cfg):
        self.cfg = cfg

        logging.info(f'--> trainer: {self.__class__.__name__}')

        self.setup()
        self.build_datasets()
        self.build_models()
        self.resume_from_ckpt()

    def setup(self):
        """Reset iteration counters / best accuracy and open the TensorBoard writer."""
        self.start_ite = 0
        self.ite = 0
        self.best_acc = 0.
        self.tb_writer = SummaryWriter(self.cfg.TRAIN.OUTPUT_TB)

    def build_datasets(self):
        """Create the four loaders: source/target x train/test.

        The dataset class is selected by ``cfg.DATASET.NAME``; any other name
        raises ``ValueError``.
        """
        logging.info(f'--> building dataset from: {self.cfg.DATASET.NAME}')
        self.dataset_loaders = {}

        # dataset loaders
        if self.cfg.DATASET.NAME == 'office_home':
            dataset = OfficeHome
        elif self.cfg.DATASET.NAME == 'domainnet':
            dataset = DomainNet
        else:
            raise ValueError(f'Dataset {self.cfg.DATASET.NAME} not found')

        self.dataset_loaders['source_train'] = DataLoader(
            dataset(self.cfg.DATASET.ROOT, self.cfg.DATASET.SOURCE, status='train'),
            batch_size=self.cfg.TRAIN.BATCH_SIZE_SOURCE,
            shuffle=True,
            num_workers=self.cfg.WORKERS,
            drop_last=True
        )
        # Only this loader passes ``trim``; the source evaluation split uses status='val'.
        self.dataset_loaders['source_test'] = DataLoader(
            dataset(self.cfg.DATASET.ROOT, self.cfg.DATASET.SOURCE, status='val', trim=self.cfg.DATASET.TRIM),
            batch_size=self.cfg.TRAIN.BATCH_SIZE_TEST,
            shuffle=False,
            num_workers=self.cfg.WORKERS,
            drop_last=False
        )
        self.dataset_loaders['target_train'] = DataLoader(
            dataset(self.cfg.DATASET.ROOT, self.cfg.DATASET.TARGET, status='train'),
            batch_size=self.cfg.TRAIN.BATCH_SIZE_TARGET,
            shuffle=True,
            num_workers=self.cfg.WORKERS,
            drop_last=True
        )
        self.dataset_loaders['target_test'] = DataLoader(
            dataset(self.cfg.DATASET.ROOT, self.cfg.DATASET.TARGET, status='test'),
            batch_size=self.cfg.TRAIN.BATCH_SIZE_TEST,
            shuffle=False,
            num_workers=self.cfg.WORKERS,
            drop_last=False
        )
        # len() of a DataLoader is its number of batches, so these are batches per epoch.
        self.len_src = len(self.dataset_loaders['source_train'])
        self.len_tar = len(self.dataset_loaders['target_train'])
        logging.info(f'    source {self.cfg.DATASET.SOURCE}: {self.len_src}'
                     f'/{len(self.dataset_loaders["source_test"])}')
        logging.info(f'    target {self.cfg.DATASET.TARGET}: {self.len_tar}'
                     f'/{len(self.dataset_loaders["target_test"])}')

    def build_models(self):
        """Build the base network, register it for checkpointing and create the optimizer."""
        logging.info(f'--> building models: {self.cfg.MODEL.BASENET}')
        self.base_net = self.build_base_models()
        self.registed_models = {'base_net': self.base_net}
        parameter_list = self.base_net.get_parameters()
        self.model_parameters()
        self.build_optim(parameter_list)

    def build_base_models(self):
        """Instantiate the network named by ``cfg.MODEL.BASENET`` and move it to CUDA."""
        basenet_name = self.cfg.MODEL.BASENET
        kwargs = {
            'pretrained': self.cfg.MODEL.PRETRAIN,
            'num_classes': self.cfg.DATASET.NUM_CLASSES,
        }

        # NOTE(review): eval() executes the configured string as Python; this is only
        # safe while the config comes from a trusted source.
        basenet = eval(basenet_name)(**kwargs).cuda()

        return basenet

    def model_parameters(self):
        """Log the parameter count (in millions) of every registered model."""
        for k, v in self.registed_models.items():
            logging.info(f'    {k} paras: '
                         f'{(sum(p.numel() for p in v.parameters()) / 1e6):.2f}M')

    def build_optim(self, parameter_list: list):
        """Create a Nesterov SGD optimizer over ``parameter_list`` and select the LR scheduler."""
        self.optimizer = optim.SGD(
            parameter_list,
            lr=self.cfg.TRAIN.LR,
            momentum=self.cfg.OPTIM.MOMENTUM,
            weight_decay=self.cfg.OPTIM.WEIGHT_DECAY,
            nesterov=True
        )
        # Stored as a plain function and invoked once per iteration in train().
        self.lr_scheduler = inv_lr_scheduler

    def resume_from_ckpt(self):
        """Restore models, optimizer, iteration and best accuracy if ``models-last.pt`` exists."""
        last_ckpt = os.path.join(self.cfg.TRAIN.OUTPUT_CKPT, 'models-last.pt')
        if os.path.exists(last_ckpt):
            ckpt = torch.load(last_ckpt)
            for k, v in self.registed_models.items():
                v.load_state_dict(ckpt[k])
            self.optimizer.load_state_dict(ckpt['optimizer'])
            self.start_ite = ckpt['ite']
            self.best_acc = ckpt['best_acc']
            logging.info(f'> loading ckpt from {last_ckpt} | ite: {self.start_ite} | best_acc: {self.best_acc:.3f}')
        else:
            logging.info('--> training from scratch')

    def train(self):
        """Run iterations ``start_ite`` .. ``cfg.TRAIN.TTL_ITE - 1``.

        Each iteration optionally evaluates, updates the learning rate, pulls one
        source and one target batch, calls ``one_step`` and periodically snapshots.
        """
        # start training
        for _, v in self.registed_models.items():
            v.train()
        for self.ite in range(self.start_ite, self.cfg.TRAIN.TTL_ITE):
            # test on the last iteration of every TEST_FREQ window, but never on the
            # very first iteration of this run
            if self.ite % self.cfg.TRAIN.TEST_FREQ == self.cfg.TRAIN.TEST_FREQ - 1 and self.ite != self.start_ite:
                self.base_net.eval()
                self.test()
                self.base_net.train()

            # NOTE(review): reads cfg.METHOD.HDA.LR_MULT even in this non-HDA base trainer.
            self.current_lr = self.lr_scheduler(
                self.optimizer,
                ite_rate=self.ite / self.cfg.TRAIN.TTL_ITE * self.cfg.METHOD.HDA.LR_MULT,
                lr=self.cfg.TRAIN.LR,
            )

            # dataloader: (re)create an iterator at the start of the run and whenever
            # the iteration index is a multiple of the loader length
            if self.ite % self.len_src == 0 or self.ite == self.start_ite:
                iter_src = iter(self.dataset_loaders['source_train'])
            if self.ite % self.len_tar == 0 or self.ite == self.start_ite:
                iter_tar = iter(self.dataset_loaders['target_train'])

            # forward one iteration
            data_src = iter_src.__next__()
            data_tar = iter_tar.__next__()
            self.one_step(data_src, data_tar)
            if self.ite % self.cfg.TRAIN.SAVE_FREQ == 0 and self.ite != 0:
                self.save_model(is_best=False, snap=True)

    def one_step(self, data_src, data_tar):
        """One optimisation step using cross-entropy on the source batch only.

        ``data_tar`` is accepted for interface parity with subclasses and is unused here.
        """
        inputs_src, labels_src = data_src['image'].cuda(), data_src['label'].cuda()

        outputs_all_src = self.base_net(inputs_src)  # [f, y]

        loss_cls_src = F.cross_entropy(outputs_all_src[1], labels_src)

        loss_ttl = loss_cls_src

        # update
        self.step(loss_ttl)

        # display
        if self.ite % self.cfg.TRAIN.PRINT_FREQ == 0:
            self.display([
                f'l_cls_src: {loss_cls_src.item():.3f}',
                f'l_ttl: {loss_ttl.item():.3f}',
                f'best_acc: {self.best_acc:.3f}',
            ])
            # tensorboard
            self.update_tb({
                'l_cls_src': loss_cls_src.item(),
                'l_ttl': loss_ttl.item(),
            })

    def display(self, data: list):
        """Log the iteration, current LR and the given pre-formatted metric strings."""
        log_str = f'I:  {self.ite}/{self.cfg.TRAIN.TTL_ITE} | lr: {self.current_lr:.5f} '
        # update
        for _str in data:
            log_str += '| {} '.format(_str)
        logging.info(log_str)

    def update_tb(self, data: dict):
        """Write each ``name -> value`` pair to TensorBoard at the current iteration."""
        for k, v in data.items():
            self.tb_writer.add_scalar(k, v, self.ite)

    def step(self, loss_ttl):
        """Zero gradients, back-propagate ``loss_ttl`` and apply one optimizer step."""
        self.optimizer.zero_grad()
        loss_ttl.backward()
        self.optimizer.step()

    def test(self):
        """Evaluate on both test loaders, track the best target accuracy, log and checkpoint."""
        logging.info('--> testing on source_test')
        src_acc = self.test_func(self.dataset_loaders['source_test'], self.base_net)
        logging.info('--> testing on target_test')
        tar_acc = self.test_func(self.dataset_loaders['target_test'], self.base_net)
        is_best = False
        # "best" is decided on target accuracy only
        if tar_acc > self.best_acc:
            self.best_acc = tar_acc
            is_best = True

        # display
        log_str = f'I:  {self.ite}/{self.cfg.TRAIN.TTL_ITE} | src_acc: {src_acc:.3f} | tar_acc: {tar_acc:.3f} | ' \
                  f'best_acc: {self.best_acc:.3f}'
        logging.info(log_str)

        # save results
        log_dict = {
            'I': self.ite,
            'src_acc': src_acc,
            'tar_acc': tar_acc,
            'best_acc': self.best_acc
        }
        write_log(self.cfg.TRAIN.OUTPUT_RESFILE, log_dict)

        # tensorboard
        self.tb_writer.add_scalar('tar_acc', tar_acc, self.ite)
        self.tb_writer.add_scalar('src_acc', src_acc, self.ite)

        self.save_model(is_best=is_best)

    def test_func(self, loader, model):
        """Return the running average accuracy of ``model`` over ``loader``.

        Gradients are disabled; switching the model to eval mode is left to the caller.
        """
        with torch.no_grad():
            iter_test = iter(loader)
            print_freq = max(len(loader) // 5, self.cfg.TRAIN.PRINT_FREQ)
            accs = AverageMeter()
            for i in range(len(loader)):
                if i % print_freq == print_freq - 1:
                    logging.info('    I:  {}/{} | acc: {:.3f}'.format(i, len(loader), accs.avg))
                data = iter_test.__next__()
                inputs, labels = data['image'].cuda(), data['label'].cuda()
                outputs_all = model(inputs)  # [f, y, ...]
                outputs = outputs_all[1]

                acc = accuracy(outputs, labels)[0]
                # each batch is recorded together with its sample count
                accs.update(acc.item(), labels.size(0))

        return accs.avg

    def save_model(self, is_best=False, snap=False):
        """Collect optimizer/model state plus progress counters and hand them to ``save_model``."""
        data_dict = {
            'optimizer': self.optimizer.state_dict(),
            'ite': self.ite,
            'best_acc': self.best_acc
        }
        for k, v in self.registed_models.items():
            data_dict.update({k: v.state_dict()})
        # Calls the module-level save_model helper, not this method.
        save_model(self.cfg.TRAIN.OUTPUT_CKPT, data_dict=data_dict, ite=self.ite, is_best=is_best, snap=snap)


In [None]:

import logging

import torch
import torch.nn as nn
import torch.nn.functional as F

from utils.torch_funcs import entropy_func
from trainer.base_trainer import BaseTrainer
from models.discriminator import *
from utils.loss import d_align_uda
from utils.utils import get_coeff


class DANN(BaseTrainer):
    """Adversarial trainer: base network plus a discriminator fed with backbone features."""

    def __init__(self, cfg):
        super().__init__(cfg)

    def build_models(self):
        """Build the base network and discriminator, register both, create the optimizer."""
        logging.info(f'--> building models: {self.cfg.MODEL.BASENET}')
        # backbone
        self.base_net = self.build_base_models()
        # discriminator, sized from the backbone feature width
        self.d_net = eval(self.cfg.MODEL.DNET)(
            in_feature=self.base_net.fdim,
            hidden_size=self.cfg.MODEL.D_HIDDEN_SIZE,
            out_feature=self.cfg.MODEL.D_OUTDIM
        ).cuda()

        self.registed_models = {'base_net': self.base_net, 'd_net': self.d_net}
        self.model_parameters()
        self.build_optim(self.base_net.get_parameters() + self.d_net.get_parameters())

    def one_step(self, data_src, data_tar):
        """One step: source cross-entropy plus weighted domain-alignment loss."""
        src_images, src_labels = data_src['image'].cuda(), data_src['label'].cuda()
        tar_images, tar_labels = data_tar['image'].cuda(), data_tar['label'].cuda()

        src_out = self.base_net(src_images)
        tar_out = self.base_net(tar_images)

        features_all = torch.cat((src_out[0], tar_out[0]), dim=0)
        logits_all = torch.cat((src_out[1], tar_out[1]), dim=0)
        softmax_all = F.softmax(logits_all, dim=1)

        # Monitoring only: computed on detached target logits.
        tar_logits_detached = tar_out[1].data
        ent_tar = entropy_func(F.softmax(tar_logits_detached, dim=1)).mean()

        # classification (the target term is logged, never optimised)
        loss_cls_src = F.cross_entropy(src_out[1], src_labels)
        loss_cls_tar = F.cross_entropy(tar_logits_detached, tar_labels)

        # domain alignment
        grl_coeff = get_coeff(self.ite, max_iter=self.cfg.TRAIN.TTL_ITE)
        loss_alg = d_align_uda(
            softmax_all, features_all, self.d_net,
            coeff=grl_coeff, ent=self.cfg.METHOD.ENT
        )

        loss_ttl = loss_cls_src + loss_alg * self.cfg.METHOD.W_ALG

        # update
        self.step(loss_ttl)

        # display
        if self.ite % self.cfg.TRAIN.PRINT_FREQ != 0:
            return

        scalars = {
            'l_cls_src': loss_cls_src.item(),
            'l_cls_tar': loss_cls_tar.item(),
            'l_alg': loss_alg.item(),
            'l_ttl': loss_ttl.item(),
            'ent_tar': ent_tar.item(),
        }
        lines = [f'{name}: {value:.3f}' for name, value in scalars.items()]
        lines.append(f'best_acc: {self.best_acc:.3f}')
        self.display(lines)
        # tensorboard
        self.update_tb(scalars)


In [None]:
import logging

import torch
import torch.nn as nn
import torch.nn.functional as F

from utils.torch_funcs import entropy_func
import utils.loss as loss
from trainer.da.dann import DANN
from models import *
from utils.utils import get_coeff


__all__ = ['ToAlign']


class ToAlign(DANN):
    """DANN variant that aligns on class probabilities and uses HDA + ToAlign features.

    The base network is created with ``hda=True`` and ``toalign=True`` and the
    discriminator input width is ``cfg.DATASET.NUM_CLASSES`` (it receives softmax
    outputs, not backbone features).
    """

    def __init__(self, cfg):
        super(ToAlign, self).__init__(cfg)

    def build_base_models(self):
        """Instantiate ``cfg.MODEL.BASENET`` with HDA and ToAlign enabled, on CUDA."""
        basenet_name = self.cfg.MODEL.BASENET
        kwargs = {
            'pretrained': self.cfg.MODEL.PRETRAIN,
            'num_classes': self.cfg.DATASET.NUM_CLASSES,
            'hda': True,
            'toalign': True,
        }

        # NOTE(review): eval() executes the configured string as Python; this is only
        # safe while the config comes from a trusted source.
        basenet = eval(basenet_name)(**kwargs).cuda()

        return basenet

    def build_models(self):
        """Build the base network and discriminator, register both, create the optimizer."""
        logging.info(f'--> building models: {self.cfg.MODEL.BASENET}')
        # backbone
        self.base_net = self.build_base_models()
        # discriminator
        self.d_net = eval(self.cfg.MODEL.DNET)(
            in_feature=self.cfg.DATASET.NUM_CLASSES,
            hidden_size=self.cfg.MODEL.D_HIDDEN_SIZE,
            out_feature=self.cfg.MODEL.D_OUTDIM
        ).cuda()

        self.registed_models = {'base_net': self.base_net, 'd_net': self.d_net}
        self.model_parameters()
        parameter_list = self.base_net.get_parameters() + self.d_net.get_parameters()
        self.build_optim(parameter_list)

    def one_step(self, data_src, data_tar):
        """One step: source cross-entropy + weighted alignment loss + HDA regulariser.

        Raises:
            ValueError: if ``cfg.TASK.NAME`` is neither ``'UDA'`` nor ``'MSDA'``.
                Previously this case fell through and failed later on the unbound
                ``loss_alg`` variable.
        """
        inputs_src, labels_src = data_src['image'].cuda(), data_src['label'].cuda()
        inputs_tar, labels_tar = data_tar['image'].cuda(), data_tar['label'].cuda()

        # --------- classification --------------
        outputs_all_src = self.base_net(inputs_src)  # [f, y, z]
        assert len(outputs_all_src) == 3, \
            f'Expected return with size 3, but got {len(outputs_all_src)}'
        loss_cls_src = F.cross_entropy(outputs_all_src[1], labels_src)
        # third output of the plain source forward pass
        focals_src = outputs_all_src[-1]

        # --------- alignment --------------
        # Second source pass with label-guided feature re-weighting; its logits are the
        # ones fed to the discriminator below.
        outputs_all_src = self.base_net(inputs_src, toalign=True, labels=labels_src)  # [f_p, y_p, z_p]
        outputs_all_tar = self.base_net(inputs_tar)  # [f, y, z]
        assert len(outputs_all_src) == 3 and len(outputs_all_tar) == 3, \
            f'Expected return with size 3, but got {len(outputs_all_src)}'
        focals_tar = outputs_all_tar[-1]

        logits_all = torch.cat((outputs_all_src[1], outputs_all_tar[1]), dim=0)
        softmax_all = nn.Softmax(dim=1)(logits_all)
        focals_all = torch.cat((focals_src, focals_tar), dim=0)

        # Monitoring only: computed on detached target logits.
        ent_tar = entropy_func(nn.Softmax(dim=1)(outputs_all_tar[1].data)).mean()

        # classificaiton loss (target term is logged, never optimised)
        loss_cls_tar = F.cross_entropy(outputs_all_tar[1].data, labels_tar)

        # domain alignment
        if self.cfg.TASK.NAME == 'UDA':
            loss_alg = loss.d_align_uda(
                softmax_output=softmax_all, d_net=self.d_net,
                coeff=get_coeff(self.ite, max_iter=self.cfg.TRAIN.TTL_ITE), ent=self.cfg.METHOD.ENT
            )
        elif self.cfg.TASK.NAME == 'MSDA':
            loss_alg = loss.d_align_msda(
                softmax_output=softmax_all, d_net=self.d_net,
                coeff=get_coeff(self.ite, max_iter=self.cfg.TRAIN.TTL_ITE), ent=self.cfg.METHOD.ENT,
                batchsizes=[inputs_src.shape[0], inputs_tar.shape[0]]
            )
        else:
            raise ValueError(
                f"Unsupported cfg.TASK.NAME {self.cfg.TASK.NAME!r}; expected 'UDA' or 'MSDA'"
            )

        # hda: mean absolute value of the third outputs
        loss_hda = focals_all.abs().mean()

        loss_ttl = loss_cls_src + loss_alg * self.cfg.METHOD.W_ALG + loss_hda

        # update
        self.step(loss_ttl)

        # display
        if self.ite % self.cfg.TRAIN.PRINT_FREQ == 0:
            self.display([
                f'l_cls_src: {loss_cls_src.item():.3f}',
                f'l_cls_tar: {loss_cls_tar.item():.3f}',
                f'l_alg: {loss_alg.item():.3f}',
                f'l_hda: {loss_hda.item():.3f}',
                f'l_ttl: {loss_ttl.item():.3f}',
                f'ent_tar: {ent_tar.item():.3f}',
                f'best_acc: {self.best_acc:.3f}',
            ])
            # tensorboard
            self.update_tb({
                'l_cls_src': loss_cls_src.item(),
                'l_cls_tar': loss_cls_tar.item(),
                'l_alg': loss_alg.item(),
                'l_hda': loss_hda.item(),
                'l_ttl': loss_ttl.item(),
                'ent_tar': ent_tar.item(),
            })


# MAIN

In [None]:
from argparse import Namespace

# Hand-built, attribute-style configuration tree. Each sub-namespace groups the
# options of one concern (model, dataset, training loop, method, optimizer).
# NOTE(review): the meaning of EMBED_INPUT / EMBED_DIM / NUM_DIM / CATE_INDEX is
# not visible in this cell — confirm against the model that consumes them.
cfg = Namespace(
    MODEL=Namespace(
        BASENET="EmbedNN",
        DNET='Discriminator',
        EMBED_INPUT=[3, 131, 4, 483, 103, 5, 106, 4],
        EMBED_DIM=[1, 3, 1, 4, 3, 1, 3, 1],
        NUM_DIM=43,
        D_HIDDEN_SIZE=32,
        D_OUTDIM=1,
    ),
    DATASET=Namespace(
        NUM_CLASSES=2,
    ),
    TRAIN=Namespace(
        TTL_ITE=10000,
        PRINT_FREQ=16,
        OUTPUT_CKPT="./",
        BATCH_SIZE=4096,
        TEST_SIZE=0.25,
        TEST_FREQ=50,
        CATE_INDEX=8,
    ),
    METHOD=Namespace(
        HDA=Namespace(LR_MULT=1.0),
        ENT=True,
        W_ALG=1.0,
    ),
    OPTIM=Namespace(
        MOMENTUM=0.9,
        WEIGHT_DECAY=5e-4,
    ),
    WORKERS=5,
)

In [None]:
! pip install yacs

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting yacs
  Downloading yacs-0.1.8-py3-none-any.whl (14 kB)
Installing collected packages: yacs
Successfully installed yacs-0.1.8


In [None]:
# --------------------------------------------------------
# Copyright (c) 2021 Microsoft
# Licensed under The MIT License
# --------------------------------------------------------

import os

from yacs.config import CfgNode as CN


# Default configuration tree. get_default_and_update_cfg() clones it and then
# overlays a YAML file, optional key/value overrides and run arguments.
_C = CN()
_C.SEED = 123
_C.WORKERS = 8
_C.TRAINER = 'Trainer'  # name of the trainer class to run


# tasks
_C.TASK = CN()
_C.TASK.NAME = 'UDA'
_C.TASK.SSDA_SHOT = 1

# ================= training ====================
_C.TRAIN = CN()
_C.TRAIN.TEST_FREQ = 500
_C.TRAIN.PRINT_FREQ = 50
_C.TRAIN.SAVE_FREQ = 5000
_C.TRAIN.TTL_ITE = 8000

_C.TRAIN.BATCH_SIZE_SOURCE = 36
_C.TRAIN.BATCH_SIZE_TARGET = 36
_C.TRAIN.BATCH_SIZE_TEST = 36
_C.TRAIN.LR = 0.001

# Output locations. OUTPUT_CKPT / OUTPUT_LOG / OUTPUT_TB / OUTPUT_RESFILE are
# rewritten to full paths under OUTPUT_ROOT by get_default_and_update_cfg().
_C.TRAIN.OUTPUT_ROOT = 'temp'
_C.TRAIN.OUTPUT_DIR = ''
_C.TRAIN.OUTPUT_LOG = 'log'
_C.TRAIN.OUTPUT_TB = 'tensorboard'
_C.TRAIN.OUTPUT_CKPT = 'ckpt'
_C.TRAIN.OUTPUT_RESFILE = 'log.txt'

# ================= optimizer ====================
_C.OPTIM = CN()
_C.OPTIM.WEIGHT_DECAY = 5e-4
_C.OPTIM.MOMENTUM = 0.9

# ================= models ====================
_C.MODEL = CN()
_C.MODEL.PRETRAIN = True
_C.MODEL.BASENET = 'resnet50'
_C.MODEL.BASENET_DOMAIN_EBD = False  # for domain embedding for transformer
_C.MODEL.DNET = 'Discriminator'
_C.MODEL.D_INDIM = 0
_C.MODEL.D_OUTDIM = 1
_C.MODEL.D_HIDDEN_SIZE = 1024
_C.MODEL.D_WGAN_CLIP = 0.01
_C.MODEL.VIT_DPR = 0.1
_C.MODEL.VIT_USE_CLS_TOKEN = True
_C.MODEL.VIT_PRETRAIN_EXLD = []
# extra layer
_C.MODEL.EXT_LAYER = False
_C.MODEL.EXT_NUM_TOKENS = 100
_C.MODEL.EXT_NUM_LAYERS = 1
_C.MODEL.EXT_NUM_HEADS = 24
_C.MODEL.EXT_LR = 10.
_C.MODEL.EXT_DPR = 0.1
_C.MODEL.EXT_SKIP = True
_C.MODEL.EXT_FEATURE = 768

# ================= dataset ====================
_C.DATASET = CN()
_C.DATASET.ROOT = ''
_C.DATASET.NUM_CLASSES = 10  # placeholder; overwritten per dataset name when the cfg is built
_C.DATASET.NAME = 'office_home'
_C.DATASET.SOURCE = []
_C.DATASET.TARGET = []
_C.DATASET.TRIM = 0

# ================= method ====================
_C.METHOD = CN()
_C.METHOD.W_ALG = 1.0  # weight of the alignment loss in the total loss
_C.METHOD.ENT = False

# HDA
_C.METHOD.HDA = CN()
_C.METHOD.HDA.W_HDA = 1.0
_C.METHOD.HDA.LR_MULT = 1.0  # set as 5.0 to tune the lr_schedule to follow the setting of original HDA


def get_default_and_update_cfg(args):
    """Build the run configuration from the defaults, a YAML file and run arguments.

    Starts from a clone of ``_C``, merges the YAML file ``args.cfg`` and the
    optional override list ``args.opts``, then applies the seed, data root,
    source/target domains, class count and output locations. The checkpoint,
    log and tensorboard directories are created on disk.

    Args:
        args: object exposing ``cfg``, ``opts``, ``seed``, ``data_root``,
            ``source``, ``target``, ``output_root`` and ``output_dir`` as
            attributes. ``source`` / ``target`` may each be a list of domain
            names or a single string (treated as one name).

    Returns:
        The updated config node (not frozen).

    Raises:
        NotImplementedError: if ``cfg.DATASET.NAME`` has no known class count.
        KeyError: in MSDA mode, if the dataset has no abbreviation table.
    """
    cfg = _C.clone()
    cfg.merge_from_file(args.cfg)
    if args.opts:
        cfg.merge_from_list(args.opts)

    # the seed from the arguments always wins over the YAML value
    cfg.SEED = args.seed

    if args.data_root:
        cfg.DATASET.ROOT = args.data_root

    # dataset maps: one-letter abbreviation -> full domain name
    maps = {
        'office_home': {
            'p': 'product',
            'a': 'art',
            'c': 'clipart',
            'r': 'real_world'
        },
        'domainnet': {
            'c': 'clipart',
            'i': 'infograph',
            'p': 'painting',
            'q': 'quickdraw',
            'r': 'real',
            's': 'sketch',
        },
    }
    # Datasets without an abbreviation table (e.g. 'office', 'visda-2017') keep
    # their domain names as given instead of failing with a KeyError here.
    name_map = maps.get(cfg.DATASET.NAME, {})

    # A bare string is one domain name; iterating it would split it into characters.
    targets = [args.target] if isinstance(args.target, str) else list(args.target)

    # MSDA: every domain of the dataset except the (first) target is a source
    if cfg.TASK.NAME == 'MSDA':
        sources = list(maps[cfg.DATASET.NAME])
        sources.remove(targets[0])
        args.source = sources  # written back so callers logging ``args`` see the effective sources
    else:
        sources = [args.source] if isinstance(args.source, str) else list(args.source)

    # source & target: expand abbreviations, pass unknown names through unchanged
    cfg.DATASET.SOURCE = [name_map.get(_d, _d) for _d in sources]
    cfg.DATASET.TARGET = [name_map.get(_d, _d) for _d in targets]

    # class
    if cfg.DATASET.NAME == 'office_home':
        cfg.DATASET.NUM_CLASSES = 65
    elif cfg.DATASET.NAME == 'office':
        cfg.DATASET.NUM_CLASSES = 31
    elif cfg.DATASET.NAME == 'visda-2017':
        cfg.DATASET.NUM_CLASSES = 12
    elif cfg.DATASET.NAME == 'domainnet' or cfg.DATASET.NAME == 'uda_domainnet':
        cfg.DATASET.NUM_CLASSES = 345
    elif cfg.DATASET.NAME == 'ssda-domainnet':
        cfg.DATASET.NUM_CLASSES = 126
    else:
        raise NotImplementedError(f'cfg.DATASET.NAME: {cfg.DATASET.NAME} not imeplemented')

    # output
    if args.output_root:
        cfg.TRAIN.OUTPUT_ROOT = args.output_root
    if args.output_dir:
        cfg.TRAIN.OUTPUT_DIR = args.output_dir
    else:
        # default run directory: <sources>2<targets>_<seed>
        cfg.TRAIN.OUTPUT_DIR = '_'.join(cfg.DATASET.SOURCE) + '2' + '_'.join(cfg.DATASET.TARGET) + '_' + str(args.seed)

    # resolve the per-run output paths and make sure they exist
    cfg.TRAIN.OUTPUT_CKPT = os.path.join(cfg.TRAIN.OUTPUT_ROOT, 'ckpt', cfg.TRAIN.OUTPUT_DIR)
    cfg.TRAIN.OUTPUT_LOG = os.path.join(cfg.TRAIN.OUTPUT_ROOT, 'log', cfg.TRAIN.OUTPUT_DIR)
    cfg.TRAIN.OUTPUT_TB = os.path.join(cfg.TRAIN.OUTPUT_ROOT, 'tensorboard', cfg.TRAIN.OUTPUT_DIR)
    os.makedirs(cfg.TRAIN.OUTPUT_CKPT, exist_ok=True)
    os.makedirs(cfg.TRAIN.OUTPUT_LOG, exist_ok=True)
    os.makedirs(cfg.TRAIN.OUTPUT_TB, exist_ok=True)
    cfg.TRAIN.OUTPUT_RESFILE = os.path.join(cfg.TRAIN.OUTPUT_LOG, 'log.txt')

    return cfg


def check_cfg(cfg):
    """Resolve output paths, expand domain abbreviations and set the class count.

    Mutates ``cfg`` in place and creates the checkpoint, log and tensorboard
    directories under ``cfg.TRAIN.OUTPUT_ROOT``.

    Args:
        cfg: config object with ``TRAIN.OUTPUT_ROOT``, ``TRAIN.OUTPUT_DIR``,
            ``DATASET.NAME``, ``DATASET.SOURCE`` and ``DATASET.TARGET``.

    Returns:
        The same ``cfg`` object, updated.

    Raises:
        NotImplementedError: if ``cfg.DATASET.NAME`` has no known class count.
    """
    # OUTPUT
    cfg.TRAIN.OUTPUT_CKPT = os.path.join(cfg.TRAIN.OUTPUT_ROOT, 'ckpt', cfg.TRAIN.OUTPUT_DIR)
    cfg.TRAIN.OUTPUT_LOG = os.path.join(cfg.TRAIN.OUTPUT_ROOT, 'log', cfg.TRAIN.OUTPUT_DIR)
    cfg.TRAIN.OUTPUT_TB = os.path.join(cfg.TRAIN.OUTPUT_ROOT, 'tensorboard', cfg.TRAIN.OUTPUT_DIR)
    os.makedirs(cfg.TRAIN.OUTPUT_CKPT, exist_ok=True)
    os.makedirs(cfg.TRAIN.OUTPUT_LOG, exist_ok=True)
    os.makedirs(cfg.TRAIN.OUTPUT_TB, exist_ok=True)
    cfg.TRAIN.OUTPUT_RESFILE = os.path.join(cfg.TRAIN.OUTPUT_LOG, 'log.txt')

    # dataset: one-letter abbreviation -> full domain name
    maps = {
        'office_home': {
            'p': 'product',
            'a': 'art',
            'c': 'clipart',
            'r': 'real_world'
        },
        'domainnet': {
            'c': 'clipart',
            'i': 'infograph',
            'p': 'painting',
            'q': 'quickdraw',
            'r': 'real',
            's': 'sketch',
        },
    }
    # Datasets without a table keep their names as given, so the class-count
    # check below (not a KeyError here) decides whether the dataset is supported.
    name_map = maps.get(cfg.DATASET.NAME, {})
    # Names that are not abbreviations (e.g. already expanded) pass through
    # unchanged, which makes repeated calls on the same cfg safe.
    cfg.DATASET.SOURCE = [name_map.get(_d, _d) for _d in cfg.DATASET.SOURCE]
    cfg.DATASET.TARGET = [name_map.get(_d, _d) for _d in cfg.DATASET.TARGET]

    # class
    if cfg.DATASET.NAME == 'office_home':
        cfg.DATASET.NUM_CLASSES = 65
    elif cfg.DATASET.NAME == 'domainnet' or cfg.DATASET.NAME == 'uda_domainnet':
        cfg.DATASET.NUM_CLASSES = 345
    else:
        raise NotImplementedError(f'cfg.DATASET.NAME: {cfg.DATASET.NAME} not imeplemented')

    return cfg


In [None]:
# source and target domains can be defined by "--source" and "--target"
root = '/content/data/office_home' 
s = 'a'
t = 'r'
op = 'exp'
! python main.py configs/uda_office_home_toalign.yaml --data_root '/content/data/office_home' --source 'a' --target 'r' --output_root 'exp'

{
    "cfg": "configs/uda_office_home_toalign.yaml",
    "seed": 123,
    "source": [
        "a"
    ],
    "target": [
        "r"
    ],
    "output_root": "exp",
    "output_dir": null,
    "data_root": "/content/data/office_home",
    "opts": null
}
DATASET:
    NAME: office_home
    NUM_CLASSES: 65
    ROOT: /content/data/office_home
    SOURCE: [art]
    TARGET: [real_world]
    TRIM: 0
METHOD:
    ENT: true
    HDA: {LR_MULT: 1.0, W_HDA: 1.0}
    W_ALG: 1.0
MODEL:
    BASENET: resnet50
    BASENET_DOMAIN_EBD: false
    DNET: Discriminator
    D_HIDDEN_SIZE: 1024
    D_INDIM: 1024
    D_OUTDIM: 1
    D_WGAN_CLIP: 0.01
    EXT_DPR: 0.1
    EXT_FEATURE: 768
    EXT_LAYER: false
    EXT_LR: 10.0
    EXT_NUM_HEADS: 24
    EXT_NUM_LAYERS: 1
    EXT_NUM_TOKENS: 100
    EXT_SKIP: true
    PRETRAIN: true
    VIT_DPR: 0.1
    VIT_PRETRAIN_EXLD: []
    VIT_USE_CLS_TOKEN: true
OPTIM: {MOMENTUM: 0.9, WEIGHT_DECAY: 0.0005}
SEED: 123
TASK: {NAME: UDA, SSDA_SHOT: 1}
TRAIN: {BATCH_SIZE_SOURCE: 

In [None]:

import json
import argparse
from argparse import Namespace

import torch.backends.cudnn as cudnn
import easydict

from configs.defaults import get_default_and_update_cfg
from utils.utils import create_logger, set_seed
from trainer import *


def parse_args():
    """Return the run arguments as an attribute-accessible mapping.

    The values are hard-coded here instead of being parsed with ``argparse``
    (this cell runs inside a notebook rather than from a command line). The
    keys are the ones ``get_default_and_update_cfg`` reads: ``cfg``, ``seed``,
    ``source``, ``target``, ``output_root``, ``output_dir``, ``data_root`` and
    ``opts``.

    Returns:
        easydict.EasyDict: the argument set.

    Raises:
        AssertionError: if the file named by ``cfg`` does not exist.
    """
    args = easydict.EasyDict({
        "cfg": 'configs/test.yaml',
        "seed": 123,
        # lists of domain names, matching what ``nargs='+'`` would yield on the CLI
        "source": ['r'],
        "target": ['a'],
        "output_root": 'exp',
        "output_dir": 'od',
        "data_root": 'data/office_home',
        # no extra KEY VALUE overrides; the key must exist because it is read downstream
        "opts": None,
    })
    assert os.path.isfile(args.cfg), 'cfg file: {} not found'.format(args.cfg)

    return args


def main():
    """Run one training job: build the config, seed, log the setup, then train.

    The trainer class is looked up by the name stored in ``cfg.TRAINER``.
    """
    cli_args = parse_args()
    cfg = get_default_and_update_cfg(cli_args)
    cfg.freeze()  # no config changes past this point

    # seed
    set_seed(cfg.SEED)

    # cudnn flags chosen for repeatable runs
    cudnn.deterministic = True
    cudnn.benchmark = False

    # logger
    log = create_logger(cfg.TRAIN.OUTPUT_LOG)

    args_text = json.dumps(vars(cli_args), indent=4)
    log.info('======================= args =======================\n' + args_text)
    cfg_text = cfg.dump(indent=4)
    log.info('======================= cfg =======================\n' + cfg_text)

    # NOTE(review): eval() executes whatever string the config carries in TRAINER;
    # only run with trusted config files, or switch to an explicit name lookup.
    trainer_cls = eval(cfg.TRAINER)
    trainer_cls(cfg).train()

main()