 ## Utils

In [2]:
import numpy as np
import torch
import math
from torch import nn
import torch.nn.functional as F
import torchvision
from torchvision import transforms
import matplotlib.pyplot as plt


from sklearn.linear_model import LogisticRegressionCV
from torch import nn, optim
from tqdm import tqdm
from torch.optim.lr_scheduler import LambdaLR
from firelab.config import Config

In [12]:
args = Config.load('./CIFAR10_Z16_ae.yaml').to_dict()
args['device'] = 'cuda:0'

In [4]:
transform = transforms.Compose([transforms.ToTensor()])
train_set = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=False, transform=transform)
test_set = torchvision.datasets.CIFAR10(root='./data', train=False,
                                        download=False, transform=transform)

train_loader = torch.utils.data.DataLoader(train_set, batch_size = args['batch_size'],
                                           shuffle=True, drop_last=True)
test_loader = torch.utils.data.DataLoader(test_set, batch_size = args['batch_size'] * 10,
                                          shuffle=True, drop_last=True)

In [5]:
def Initializer(layers, slope=0.2):
    for layer in layers:
        if hasattr(layer, 'weight'):
            w = layer.weight.data
            std = 1/np.sqrt((1 + slope**2) * np.prod(w.shape[:-1]))
            w.normal_(std=std)
            
        if hasattr(layer, 'bias'):
            layer.bias.data.zero_()

class Autoencoder(nn.Module):
    def __init__(self, scales, depth, latent, colors):
        super().__init__()
             
        self.encoder = self._make_network(scales, depth, latent, colors, part='encoder')
        self.decoder = self._make_network(scales, depth, latent, colors, part='decoder')
        
    def forward(self, x):
        return self.decoder(self.encoder(x))
    
    @staticmethod
    def _make_network(scales, depth, latent, colors, part=None):
        """
        input:
        part - encoder/decoder, str
        """
        activation = nn.LeakyReLU(0.01)   
        
        sub_network = []
        
        if part == 'encoder':
            sub_network += [nn.Conv2d(colors, depth, 1, padding=1)]
            
            kp = depth
            iterable = range(scales)
            transformation = nn.AvgPool2d(2)
            
        elif part == 'decoder':
            
            kp = latent
            iterable =range(scales - 1, -1, -1)
            transformation = nn.Upsample(scale_factor=2)
        
        # joint part
        for scale in range(scales):
            k = depth << scale
            sub_network.extend([nn.Conv2d(kp, k, 3, padding=1), activation,
                                transformation])
            kp = k
        
        if part == 'encoder':
            k = depth << scales
            sub_network.extend([nn.Conv2d(kp, k, 3, padding=1), activation, nn.Conv2d(k, latent, 3, padding=1)])
        
        elif part == 'decoder':
            sub_network.extend([nn.Conv2d(kp, depth, 3, padding=1), activation, nn.Conv2d(depth, colors, 3, padding=1)])
        
        Initializer(sub_network)
        return nn.Sequential(*sub_network)
    
class Critic(nn.Module):
    def __init__(self, scales, depth, latent, colors):
        super().__init__()
        
        self.flatten = nn.Flatten()
        self.critic = Autoencoder._make_network(scales, depth, latent, colors, part='encoder')
        
    def forward(self, x):
        return self.flatten(self.critic(x)).mean(dim=1)
    
    def descriptor(self, x):
        return self.critic(x)

In [7]:
def latent_space_quality(autoencoder, dataloaders, device='cpu'):
    cdn = lambda x: x.cpu().detach().numpy()
    
    train, test = dataloaders
    autoencoder = autoencoder.to(device)
    autoencoder.eval()
    
    def inference(model, loader):
        descriptor=[]
        for idx, (X, y) in enumerate(loader):
            prediction, target = cdn(model.encoder(X.to(device))), cdn(y).reshape(-1,1)
            descriptor.append(np.hstack([prediction.reshape(prediction.shape[0], -1), target]))
        return np.vstack(descriptor)
            
    descriptor_train, descriptor_test = inference(autoencoder, train), inference(autoencoder, test)
    
    autoencoder.train()
    
    return (descriptor_train, descriptor_test)

class DescriptorDataset(torch.utils.data.Dataset):
    def __init__(self, descriptor):
        self.descriptor = descriptor
        
    def __len__(self):
        return len(self.descriptor)
    
    def __getitem__(self, idx):
        obj = torch.Tensor(self.descriptor[idx])
        return (obj[:-1].float(), obj[-1])
    
    def fit_logistic_regression(self):
        lr = LogisticRegressionCV(Cs=10, cv=5, max_iter=500)
        lr.fit(descriptor[:, :-1], descriptor[:, -1])
        
        print(f'baseline acc: {[lr.scores_[k].mean() for k in lr.scores_.keys()]}')
        
    def fit_kmeans(self):
        self.kmeans = KMeans(10)
        self.kmeans.fit(self.descriptor[:, :-1])
        
        
class SingleLayer(nn.Module):
    def __init__(self, latent_dim, n_classes, dropout=0):
        super().__init__()
             
        self.FC = nn.Linear(latent_dim, n_classes)
        
        if dropout!=0:
            self.FC = nn.Sequential(nn.Dropout(dropout), self.FC)
        
    def forward(self, x):
        return self.FC(x)
    
    
def fit_FC(autoencoder, loaders, args):
    
    descr_train, descr_test = latent_space_quality(autoencoder, loaders, device=args['device'])    
    train_descriptor_dataset, test_descriptor_dataset = DescriptorDataset(descr_train), DescriptorDataset(descr_test)
    
    train_descr_loader = torch.utils.data.DataLoader(train_descriptor_dataset, batch_size = 64, shuffle=True, drop_last=False)
    test_descr_loader = torch.utils.data.DataLoader(test_descriptor_dataset, batch_size = 1000, shuffle=True, drop_last=False)
    
    criterion_CE = nn.CrossEntropyLoss()
    test_accuracy = []    
    
    latent_dim = train_descriptor_dataset[0][0].shape[0]
    
    fc_layer = SingleLayer(latent_dim=latent_dim, n_classes=10, dropout=0).to(args['device'])
    opt_fc = optim.Adam(fc_layer.parameters(), lr=1e-3, weight_decay=1e-5)
#     scheduler = LambdaLR(opt_fc, lr_lambda=lambda epoch: 0.9 ** epoch)
        
    for epoch in range(20):
        for index, (X, y) in tqdm(enumerate(train_descr_loader), total=len(train_descr_loader),
                                  leave=False, desc=f'Fit FC, Epoch: {epoch}'):

            y_hat = fc_layer(X.to(args['device']))
            loss = criterion_CE(y_hat, y.to(args['device']).long())

            opt_fc.zero_grad()
            loss.backward()
            opt_fc.step()

        # Test Step
        fc_layer.eval()

        acc = 0
        for index, (X, y) in enumerate(test_descr_loader):
            y_hat = fc_layer(X.to(args['device'])).cpu().detach()
            acc += (y == y_hat.argmax(dim=1)).sum().item()/y_hat.shape[0]
        test_accuracy.append(acc/len(test_descr_loader))

        fc_layer.train()
        
#         if epoch%2==0:
#             scheduler.step()
        
    return test_accuracy

In [17]:
def augmentation(X):
    aug_idx = np.random.randint(5)
    if aug_idx == 0:
        return torch.flip(X, (2,))
    if aug_idx == 1:
        return torch.flip(X, (3,))
    if aug_idx == 2:
        return torch.flip(X, (3, 2))
    if aug_idx == 3:
        return torch.flip(X, (1,))
    if aug_idx == 4:
        return torch.flip(X, (1, 2))

## Augmentation experiments on AE with the classificator FC on the latent space 

In [22]:
args['device'] = 'cuda:0'

In [14]:
scales = int(round(math.log(args['width'] // args['latent_width'], 2)))
autoencoder = Autoencoder(scales=scales, depth=args['depth'],
                          latent=args['latent'], colors=args['colors']
                         ).to('cuda:0')
autoencoder.load_state_dict(torch.load('Autoencoder.torch', map_location='cuda:0'))
autoencoder.to('cuda:0')   

Autoencoder(
  (encoder): Sequential(
    (0): Conv2d(3, 64, kernel_size=(1, 1), stride=(1, 1), padding=(1, 1))
    (1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (2): LeakyReLU(negative_slope=0.01)
    (3): AvgPool2d(kernel_size=2, stride=2, padding=0)
    (4): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (5): LeakyReLU(negative_slope=0.01)
    (6): AvgPool2d(kernel_size=2, stride=2, padding=0)
    (7): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (8): LeakyReLU(negative_slope=0.01)
    (9): AvgPool2d(kernel_size=2, stride=2, padding=0)
    (10): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): LeakyReLU(negative_slope=0.01)
    (12): Conv2d(512, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  )
  (decoder): Sequential(
    (0): Conv2d(16, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): LeakyReLU(negative_slope=0.01)
    (2): Upsample(scale_factor=2.0,

In [18]:
optimizer = torch.optim.Adam(autoencoder.parameters(), lr=args['lr'], weight_decay=args['weight_decay'])

In [19]:
distances = []
autoencoder.decoder.eval()
autoencoder.to('cuda:0')

for X_origin, _ in train_loader:   
    X_aug = augmentation(X_origin)
    X_origin, X_aug = X_origin.to('cuda:0'), X_aug.to('cuda:0')
    dist = []
    for _ in range(6):
        points_origin = autoencoder.encoder(X_origin)
        points_aug = autoencoder.encoder(X_aug)

        dist.append(((points_origin - points_aug) ** 2).sum().item() / 32)

        out_origin, out_aug = autoencoder.decoder(points_origin), autoencoder.decoder(points_aug)

        loss = F.mse_loss(X_origin, out_aug) + F.mse_loss(X_aug, out_origin)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    distances.append(dist)

In [23]:
results = fit_FC(autoencoder, (train_loader, test_loader), args)

                                                                     

In [24]:
results

[0.32374471544715444,
 0.34065447154471545,
 0.3519658536585365,
 0.35992032520325207,
 0.36841626016260165,
 0.3730918699186992,
 0.37949837398373987,
 0.38428699186991866,
 0.39053739837398377,
 0.38910487804878047,
 0.39213089430894305,
 0.3969536585365853,
 0.399950406504065,
 0.39864878048780483,
 0.40145691056910565,
 0.40758373983739843,
 0.4075666666666667,
 0.40648130081300815,
 0.41184552845528455,
 0.4137276422764228]