## Imports and Preliminary

In [1]:
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt

from torchvision import models
from torch import Tensor

## ResNet (a.k.a. Feature Extractor)

In [2]:
import math
import torch.utils.model_zoo as model_zoo

# Download locations of the pretrained checkpoints, keyed by architecture
# name. `resnet18` below looks up its entry here when `args.pretrained` is set.
model_urls = {
    'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
    'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth'
}


def conv3x3(in_planes, out_planes, stride=1):
    """Build a bias-free 3x3 ``Conv2d`` that pads each border by one pixel."""
    conv_options = dict(kernel_size=3, stride=stride, padding=1, bias=False)
    return torch.nn.Conv2d(in_planes, out_planes, **conv_options)


class BasicBlock(torch.nn.Module):
    """Residual block made of two 3x3 convolutions plus a shortcut connection.

    The shortcut is the input itself, or ``downsample(input)`` when a
    ``downsample`` module is supplied.
    """

    # Ratio between the block's output channels and ``planes``; ResNet reads
    # this attribute when sizing its layers.
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super().__init__()
        nn = torch.nn
        # Only the first convolution applies ``stride``.
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        """Return ``relu(bn2(conv2(relu(bn1(conv1(x))))) + shortcut)``."""
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))

        # Project the input only when a projection module was provided.
        shortcut = x if self.downsample is None else self.downsample(x)

        out += shortcut
        return self.relu(out)


class Bottleneck(torch.nn.Module):
    """Residual bottleneck block: 1x1 reduce, 3x3, 1x1 expand, plus shortcut.

    The block maps ``inplanes`` channels to ``planes * expansion`` channels.
    The shortcut is the input itself, or ``downsample(input)`` when a
    ``downsample`` module is supplied.

    Args:
        inplanes: number of input channels.
        planes: width of the two inner convolutions.
        stride: stride of the 3x3 convolution.
        downsample: optional module applied to the input before the addition.
    """

    # Ratio between the block's output channels and ``planes``; ResNet reads
    # this attribute when sizing its layers.
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(Bottleneck, self).__init__()
        # Derive the output width from the class attribute rather than a
        # literal 4, so the block and ResNet's ``block.expansion`` arithmetic
        # cannot drift apart if a subclass overrides ``expansion``.
        out_planes = planes * self.expansion
        self.conv1 = torch.nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = torch.nn.BatchNorm2d(planes)
        self.conv2 = torch.nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
                                     padding=1, bias=False)
        self.bn2 = torch.nn.BatchNorm2d(planes)
        self.conv3 = torch.nn.Conv2d(planes, out_planes, kernel_size=1, bias=False)
        self.bn3 = torch.nn.BatchNorm2d(out_planes)
        self.relu = torch.nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        """Apply the three conv/BN stages, add the shortcut, then ReLU."""
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        # Project the input only when a projection module was provided.
        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out


class ResNet(torch.nn.Module):
    """ResNet backbone followed by global average pooling and a linear head.

    Args:
        block: callable building one residual block; it is invoked as
            ``block(inplanes, planes, stride, downsample)`` or
            ``block(inplanes, planes)`` and must expose an ``expansion``
            attribute (output channels = ``planes * expansion``).
        layers: four integers, the number of blocks in each of the four stages.
        num_classes: output width of the final linear layer.
    """

    def __init__(self, block, layers, num_classes=1000):
        # Initialise the Module machinery before assigning any attribute.
        super(ResNet, self).__init__()
        # Channel count fed to the next block; updated by ``_make_layer``.
        self.inplanes = 64
        self.conv1 = torch.nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = torch.nn.BatchNorm2d(64)
        self.relu = torch.nn.ReLU(inplace=True)
        self.maxpool = torch.nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
        # Global average pooling to 1x1. On the 7x7 map produced by a 224x224
        # input this equals the former ``AvgPool2d(7)``, but it also accepts
        # other input resolutions and holds no parameters, so state dicts are
        # unaffected.
        self.avgpool = torch.nn.AdaptiveAvgPool2d((1, 1))
        self.fc = torch.nn.Linear(512 * block.expansion, num_classes)

        for m in self.modules():
            if isinstance(m, torch.nn.Conv2d):
                # Zero-mean normal with variance 2 / (kernel area * out channels).
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, torch.nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

    def _make_layer(self, block, planes, blocks, stride=1):
        """Stack ``blocks`` residual blocks; only the first one may downsample.

        A 1x1 conv + BatchNorm projection is created for the first block's
        shortcut whenever the stride or the channel count changes.
        """
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = torch.nn.Sequential(
                torch.nn.Conv2d(self.inplanes, planes * block.expansion,
                                kernel_size=1, stride=stride, bias=False),
                torch.nn.BatchNorm2d(planes * block.expansion),
            )

        layers = [block(self.inplanes, planes, stride, downsample)]
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes))

        return torch.nn.Sequential(*layers)

    def forward(self, x):
        """Return the ``(batch, num_classes)`` output of the linear head."""
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        # Collapse the 1x1 spatial dimensions into the feature axis.
        x = x.view(x.size(0), -1)
        x = self.fc(x)

        return x


def resnet18(args, **kwargs):
    """Construct a ResNet-18 whose head emits ``2 * args.num_classes`` outputs.

    Args:
        args: object exposing ``pretrained`` (when truthy, the checkpoint at
            ``model_urls['resnet18']`` is downloaded and loaded) and
            ``num_classes`` (int).
        **kwargs: accepted for call-site compatibility; currently ignored.

    Returns:
        The ``ResNet`` instance with its ``fc`` layer replaced by a fresh
        ``Linear(in_features, args.num_classes * 2)``.
    """
    # Build the backbone with ResNet's default head width instead of
    # ``args.num_classes``: ``load_state_dict`` requires every tensor,
    # including ``fc``, to match the checkpoint's shapes, so sizing the head
    # from ``args`` made pretrained loading fail for any class count other
    # than the checkpoint's (presumably 1000, the ImageNet head). The head is
    # replaced right below, so its initial width never reaches the caller.
    model = ResNet(BasicBlock, [2, 2, 2, 2])
    if args.pretrained:
        model.load_state_dict(model_zoo.load_url(model_urls['resnet18']))
    # Swap the classifier for one twice as wide as the class count; the new
    # layer keeps PyTorch's default ``Linear`` initialisation.
    num_of_feature_map = model.fc.in_features
    model.fc = torch.nn.Linear(num_of_feature_map, args.num_classes * 2)
    return model

class TestArgs:
    """Bare attribute holder carrying ``pretrained`` and ``num_classes``."""

    def __init__(self, pretrained, num_classes):
        self.pretrained, self.num_classes = pretrained, num_classes

In [9]:

def feature_extractor(model='resnet18', n_classes=1000):
    """Return a pretrained torchvision ResNet with a ``2 * n_classes``-wide head.

    ``model`` must be ``'resnet18'`` or ``'resnet50'``; any other value raises
    ``ValueError('Unknown model')``. The network's ``fc`` layer is replaced by
    a new ``Linear`` layer with ``n_classes * 2`` outputs.
    """
    if model not in ('resnet18', 'resnet50'):
        raise ValueError('Unknown model')

    build = models.resnet18 if model == 'resnet18' else models.resnet50
    net = build(pretrained=True)

    # Keep the backbone, swap only the final fully connected layer.
    net.fc = torch.nn.Linear(net.fc.in_features, n_classes * 2)
    return net

In [10]:
# Write the printed architecture of the default feature extractor to a text file
with open('docs/resnet18_symnets.txt', 'w') as f:
    f.write(str(feature_extractor()))

In [11]:
# Write the printed architecture of the stock torchvision ResNet-18 to a text file
model = models.resnet18(pretrained=True)
with open('docs/resnet18_online.txt', 'w') as f:
    f.write(str(model))

In [10]:
# Write the printed architecture of the default feature extractor to a text file
with open('docs/resnet18_ours.txt', 'w') as f:
    f.write(str(feature_extractor()))

In [11]:
# Load a batch of source images
from PIL import Image

folder = "data/Adaptiope/product_images"
# One image per class; files are laid out as <folder>/<class>/<class>_000.jpg
source_imgs = [
    Image.open(f"{folder}/{cls}/{cls}_000.jpg").convert('RGB')
    for cls in ("bookcase", "flat iron", "ice skates")
]
# Uncomment to display the raw images:
# for img in source_imgs:
#     img.show()

In [12]:
from torchvision import transforms

# Recreate the preprocessing pipeline as reported in data_prep
preprocess = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])

# Run every source image through the pipeline
source_tns = [preprocess(img) for img in source_imgs]

# Uncomment to display the preprocessed tensors as images:
# to_pil_img = transforms.ToPILImage()
# for tns in source_tns:
#     to_pil_img(tns).show()

# Give each tensor a leading batch axis, then concatenate along that axis
source_tns = [tns.unsqueeze(0) for tns in source_tns]
source_batch = torch.cat(source_tns)
print('Batch shape =', source_batch.shape)

Batch shape = torch.Size([3, 3, 224, 224])


In [13]:
# Build the ResNet-18 based extractor; eval() is deliberately left commented
# out, so the network stays in its default (training) mode
resnet = feature_extractor('resnet18')
# resnet.eval()

# Forward the whole batch and show the shape and values of the output
features = resnet(source_batch)
print(features.shape, features, sep='\n')

torch.Size([3, 2000])
tensor([[-0.3483, -0.3169, -0.3339,  ...,  0.1621, -0.5910, -1.1216],
        [ 0.5512,  0.2212,  0.0630,  ...,  0.0849, -0.2852, -0.2853],
        [-0.0732, -0.7080,  0.5453,  ...,  0.3829, -1.1158, -1.1359]],
       grad_fn=<AddmmBackward0>)


torch.Size([3, 2000])
tensor([[ 0.2031,  0.2776, -0.0273,  ...,  0.1005,  1.0534,  0.4215],
        [ 0.4832,  0.0222, -0.1837,  ...,  0.8101,  1.5281,  0.6838],
        [ 1.0008,  0.1067, -0.0567,  ...,  0.1076,  1.6767,  0.8851]],
       grad_fn=<AddmmBackward0>)

## Domain Classifier

In [14]:
# Column index at which `split_softmax` cuts the probability matrix:
# columns [0, n_classes) form the source part, the rest the target part.
n_classes = 1000

def add_threshold(prob: Tensor):
    """Return ``prob`` with every exact zero raised to a tiny positive value.

    This keeps a later ``log`` finite when a probability has underflowed to 0.

    Args:
        prob: tensor of probabilities (any shape).

    Returns:
        ``prob`` itself when it contains no zeros; otherwise a new tensor on
        the same device and with the same dtype, in which each zero entry is
        replaced by ``1e-20`` and all other entries are unchanged.
    """
    _THRESHOLD = 1e-20
    zeros = (prob == 0)
    if not torch.any(zeros):
        return prob
    # Add out of place: an in-place ``+=`` would overwrite the caller's tensor
    # and, when ``prob`` comes from softmax, invalidate the value autograd
    # saved for the backward pass. Casting the mask to ``prob.dtype`` keeps the
    # offset on the same device and dtype as ``prob``.
    return prob + zeros.to(prob.dtype) * _THRESHOLD
      
def to_softmax(features: Tensor, split=False, source=True):
    """Turn ``features`` into zero-free probabilities via softmax over dim 1.

    When ``split`` is true, only the part selected by ``split_softmax(prob,
    source)`` is kept. The result always passes through ``add_threshold``.
    """
    prob = F.softmax(features, dim=1)
    if split:
        prob = split_softmax(prob, source)
    return add_threshold(prob)

def split_softmax(prob: Tensor, source=True):
    """Slice ``prob`` along dim 1 at the module-level ``n_classes``.

    Returns the first ``n_classes`` columns when ``source`` is truthy and the
    remaining columns otherwise.
    """
    if source:
        return prob[:, :n_classes]
    return prob[:, n_classes:]

def cross_entropy_loss(prob: Tensor):
    """Return the negated mean, over rows, of the log of each row's sum along dim 1."""
    row_mass = torch.sum(prob, dim=1)
    mean_log_mass = torch.mean(torch.log(row_mass))
    return -mean_log_mass

# Source-domain classifier: loss over the first n_classes probabilities
print(cross_entropy_loss(to_softmax(features, True, True)))

# Target-domain classifier: loss over the remaining probabilities
print(cross_entropy_loss(to_softmax(features, True, False)))

tensor(0.6818, grad_fn=<NegBackward0>)
tensor(0.7046, grad_fn=<NegBackward0>)
