In [None]:
import os
import gc
import sys
import cv2
import glob
import math
import time
import tqdm
import random
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt

import warnings
warnings.filterwarnings('ignore')

from accelerate import Accelerator

from functools import partial
from sklearn.model_selection import StratifiedKFold

import torch
import torch.nn as nn
import torch.optim as optim
import torch.functional as F
from torch.utils.data import Dataset, DataLoader

import torchvision
from torchvision import datasets
import torchvision.transforms as transforms

import albumentations as A 
from albumentations.pytorch.transforms import ToTensorV2

from transformers import get_cosine_schedule_with_warmup, AdamW

from colorama import Fore, Back, Style
# Short aliases for colorama ANSI colour codes, interpolated into log f-strings;
# `sr_` resets the terminal style after a coloured message.
r_, b_, c_ = Fore.RED, Fore.BLUE, Fore.CYAN
g_, y_, m_ = Fore.GREEN, Fore.YELLOW, Fore.MAGENTA
sr_ = Style.RESET_ALL

In [None]:


# Training hyperparameters shared by every cell below.
config = dict(
    lr=1e-3,        # peak learning rate for the optimizer
    wd=1e-2,        # weight decay
    bs=256,         # batch size for train/valid loaders
    img_size=128,   # square resize target; Model's layer sizes assume 128
    epochs=100,     # number of training epochs
    seed=1000,      # global RNG seed
)

def seed_everything(seed=42):
    """Seed Python, NumPy and PyTorch RNGs and make cuDNN deterministic.

    Args:
        seed: integer seed applied to every random number generator.

    Note:
        Setting PYTHONHASHSEED at runtime only affects interpreters started
        after this call (e.g. spawned worker processes); hash randomization of
        the current process was fixed when it started.
    """
    random.seed(seed)
    # Was misspelled 'PYTHONASSEED', which Python never reads.
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # benchmark=True lets cuDNN choose kernels by timing them, which can vary
    # between runs and undoes the `deterministic` setting above.
    torch.backends.cudnn.benchmark = False

# Seed all RNGs before any random sampling happens below.
seed_everything(config['seed'])

def sample_image_paths(pattern, n_samples):
    """Draw up to `n_samples` distinct file paths matching the glob `pattern`.

    Matches are sorted before sampling, so the selection depends only on the
    NumPy global seed and not on filesystem listing order. Sampling is without
    replacement, so no image is used twice. If fewer than `n_samples` files
    match, all of them are returned in shuffled order.

    Args:
        pattern: glob pattern passed to `glob.glob` (non-recursive).
        n_samples: maximum number of paths to return.

    Returns:
        A NumPy array of path strings. It is empty if nothing matches; callers
        should check this before building DataLoaders.
    """
    files = sorted(glob.glob(pattern))
    n_draw = min(n_samples, len(files))
    if n_draw == 0:
        return np.array([], dtype=str)
    return np.random.choice(files, size=n_draw, replace=False)


train_paths = sample_image_paths('../input/imagenetmini-1000/imagenet-mini/train/**/*.JPEG', 10000)
valid_paths = sample_image_paths('../input/imagenetmini-1000/imagenet-mini/val/**/*.JPEG', 1000)



In [None]:


def get_train_transforms():
    """Build the albumentations pipeline applied to every image.

    The image is resized to a square of side `config['img_size']`, normalized
    with `A.Normalize()` defaults, and converted to a CHW torch tensor.
    """
    side = config['img_size']
    steps = [
        A.Resize(side, side, always_apply=True),
        A.Normalize(),
        ToTensorV2(p=1.0),
    ]
    return A.Compose(steps)



In [None]:
class ImageNetDataset(Dataset):
    """Map-style dataset that loads RGB images from file paths.

    Args:
        paths: indexable sequence of image file paths.
        augmentations: callable invoked as `augmentations(image=...)` that
            returns a dict with an `'image'` key (e.g. an albumentations
            `Compose`), or a falsy value to return the raw RGB array.
    """

    def __init__(self,paths,augmentations):
        self.paths = paths
        self.augmentations = augmentations

    def __getitem__(self,idx):
        """Load image `idx`, convert BGR->RGB, and apply augmentations if set.

        Raises:
            FileNotFoundError: if OpenCV cannot read the file at that path.
        """
        path = self.paths[idx]
        image = cv2.imread(path)
        # cv2.imread reports failure by returning None instead of raising.
        # Fail here with the offending path, not later with a cryptic cvtColor error.
        if image is None:
            raise FileNotFoundError(f"Could not read image: {path}")
        # OpenCV decodes to BGR channel order.
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        if self.augmentations:
            image = self.augmentations(image=image)['image']

        return image

    def __len__(self):
        """Number of image paths in the dataset."""
        return len(self.paths)

In [5]:


# Preview one validation batch to sanity-check loading and transforms.
# A.Normalize() in get_train_transforms uses albumentations' default ImageNet
# mean/std. Undo it so imshow receives colours in [0, 1] instead of clipped values.
IMAGENET_MEAN = np.array([0.485, 0.456, 0.406])
IMAGENET_STD = np.array([0.229, 0.224, 0.225])

test_dataset = ImageNetDataset(valid_paths,augmentations=get_train_transforms())
test_dl = DataLoader(test_dataset,batch_size=16,shuffle=False,num_workers=4)

dataiter = iter(test_dl)
# Recent PyTorch DataLoader iterators have no `.next()` method; use the builtin.
sample = next(dataiter)

grid = torchvision.utils.make_grid(sample).permute(1,2,0).numpy()
img = np.clip(grid * IMAGENET_STD + IMAGENET_MEAN, 0.0, 1.0)
plt.figure(figsize=(15,15))
plt.imshow(img)
plt.axis('off')
plt.show()



In [None]:


class Model(nn.Module):
    """Convolutional variational autoencoder for 3x128x128 images.

    The encoder maps an image to the mean and log-variance of a diagonal
    Gaussian q(z|x). The decoder maps a latent sample back to an image-shaped
    tensor, used as the mean of a Gaussian likelihood whose log standard
    deviation is the single learned parameter `self.scale`. `forward` returns
    the negative ELBO (a one-sample KL estimate minus the reconstruction
    log-likelihood), averaged over the batch.

    Args:
        latent_dim: size of the latent vector z.
    """

    def __init__(self,latent_dim=512):
        super().__init__()

        self.latent_dim = latent_dim
        # Spatial side of the first decoder feature map. Two stride-2
        # transposed convs then upsample it 32 -> 64 -> 128.
        self.shape = 32

        # encode: with kernel 3, stride 2 and no padding, a 128x128 input
        # becomes 63x63 and then 31x31. Hence fc1 expects 64 * 31 * 31
        # features, which assumes 128x128 inputs.
        self.conv1 = nn.Conv2d(3,32,kernel_size=3,stride=2)
        self.conv2 = nn.Conv2d(32,64,kernel_size=3,stride=2)
        self.flatten = nn.Flatten()
        # Outputs mean and log-variance concatenated along dim 1.
        self.fc1 = nn.Linear(64*(self.shape-1)**2,2*self.latent_dim)
        self.relu = nn.ReLU()
        # Learned log standard deviation of the decoder's Gaussian likelihood.
        self.scale = nn.Parameter(torch.tensor([0.0]))

        # decode
        self.fc2 = nn.Linear(self.latent_dim,(self.shape**2) *32)
        self.conv3 = nn.ConvTranspose2d(32,64,kernel_size=2,stride=2)
        self.conv4 = nn.ConvTranspose2d(64,32,kernel_size=2,stride=2)
        self.conv5 = nn.ConvTranspose2d(32,3,kernel_size=1,stride=1)

    def encode(self,x):
        """Return (mean, logvar), each of shape (batch, latent_dim)."""
        x = self.relu(self.conv1(x))
        x = self.relu(self.conv2(x))
        # conv2's output is already ReLU'd, so no extra activation is needed here.
        x = self.flatten(x)
        x = self.fc1(x)
        mean,logvar = torch.split(x,self.latent_dim,dim=1)
        return mean, logvar

    def decode(self,eps):
        """Map latent vectors (batch, latent_dim) to (batch, 3, 128, 128) outputs."""
        x = self.relu(self.fc2(eps))
        x = torch.reshape(x,(x.shape[0],32,self.shape,self.shape))
        x = self.relu(self.conv3(x))
        x = self.relu(self.conv4(x))
        x = self.conv5(x)
        return x

    def reparamatrize(self,mean,std):
        """Draw a differentiable sample z ~ N(mean, std) (reparameterization trick)."""
        q = torch.distributions.Normal(mean,std)
        return q.rsample()

    def kl_loss(self,z,mean,std):
        """Single-sample Monte Carlo estimate of KL(q(z|x) || N(0, I)).

        Args:
            z: latent sample drawn from q(z|x).
            mean: posterior mean, same shape as z.
            std: posterior standard deviation, already exp(logvar / 2).

        Returns:
            Per-example estimate, summed over the last dimension.
        """
        p = torch.distributions.Normal(torch.zeros_like(mean),torch.ones_like(std))
        # `std` is already a standard deviation. Re-exponentiating it (the
        # previous exp(std / 2)) scored z under a different distribution than
        # the one it was sampled from.
        q = torch.distributions.Normal(mean,std)

        log_pz = p.log_prob(z)
        log_qzx = q.log_prob(z)

        kl_loss = (log_qzx - log_pz)
        kl_loss = kl_loss.sum(-1)
        return kl_loss

    def gaussian_likelihood(self,inputs,outputs,scale):
        """Log-likelihood of `inputs` under N(outputs, exp(scale)), summed over dims 1-3."""
        dist = torch.distributions.Normal(outputs,torch.exp(scale))
        log_pxz = dist.log_prob(inputs)
        return log_pxz.sum(dim=(1,2,3))

    def loss_fn(self,inputs,outputs,z,mean,std):
        """Batch-mean negative ELBO: KL estimate minus reconstruction log-likelihood."""
        kl_loss = self.kl_loss(z,mean,std)
        rec_loss = self.gaussian_likelihood(inputs,outputs,self.scale)

        return torch.mean(kl_loss - rec_loss)

    def forward(self,inputs):
        """Encode, sample, and decode `inputs`.

        Returns:
            loss: scalar negative ELBO.
            (outputs, z, mean, std): the reconstruction, the latent sample,
                and the posterior parameters.
        """
        mean,logvar = self.encode(inputs)
        std = torch.exp(logvar/2)
        z = self.reparamatrize(mean,std)
        outputs = self.decode(z)

        loss = self.loss_fn(inputs,outputs,z,mean,std)
        return loss,(outputs,z,mean,std)



In [None]:


def run():
    """Train the VAE on `train_paths` and validate on `valid_paths` every epoch.

    Hyperparameters come from the module-level `config`. Whenever the
    validation loss ties or improves, the weights are saved to
    './imagenet_vae_model.bin'.

    Returns:
        The best (lowest) validation loss seen.
    """

    def evaluate(model,valid_loader):
        """Return the mean per-batch loss of `model` over `valid_loader`."""
        model.eval()
        valid_loss = 0
        with torch.no_grad():
            for inputs in valid_loader:
                loss,_ = model(inputs)
                valid_loss += loss.item()

        valid_loss /= len(valid_loader)
        return valid_loss

    def train_and_evaluate_loop(train_loader,valid_loader,model,optimizer,
                                epoch,best_loss,lr_scheduler=None):
        """Run one training epoch, then validate and checkpoint on improvement.

        The scheduler (if any) is stepped once per batch.

        Returns:
            The possibly updated best validation loss.
        """
        # evaluate() switches to eval mode only after this loop, so setting
        # train mode once per epoch is enough.
        model.train()
        train_loss = 0
        for inputs in train_loader:
            optimizer.zero_grad()
            loss,_ = model(inputs)
            # accelerator.backward applies mixed-precision scaling and distributed
            # handling that a bare loss.backward() would bypass.
            accelerator.backward(loss)
            optimizer.step()

            train_loss += loss.item()

            if lr_scheduler is not None:
                lr_scheduler.step()

        train_loss /= len(train_loader)
        valid_loss = evaluate(model,valid_loader)

        if valid_loss <= best_loss:
            print(f"Epoch:{epoch} |Train Loss:{train_loss}|Valid Loss:{valid_loss}")
            print(f"{g_}Loss Decreased from {best_loss} to {valid_loss}{sr_}")

            best_loss = valid_loss
            # Save the plain module's weights rather than any wrapper added by
            # accelerator.prepare. accelerator.save writes once per machine.
            accelerator.save(accelerator.unwrap_model(model).state_dict(),'./imagenet_vae_model.bin')

        return best_loss

    accelerator = Accelerator()
    print(f"{accelerator.device} is used")

    model = Model()

    ## train
    train_dataset = ImageNetDataset(train_paths,augmentations=get_train_transforms())
    train_dl = DataLoader(train_dataset,batch_size=config['bs'],shuffle=True,num_workers=4)

    #valid
    valid_dataset = ImageNetDataset(valid_paths,augmentations=get_train_transforms())
    valid_dl = DataLoader(valid_dataset,batch_size=config['bs'],shuffle=False,num_workers=4)

    # transformers.AdamW is deprecated in favour of torch.optim.AdamW.
    # eps=1e-6 keeps transformers' former default.
    optimizer = optim.AdamW(model.parameters(),lr=config['lr'],weight_decay=config['wd'],eps=1e-6)
    lr_scheduler = get_cosine_schedule_with_warmup(optimizer,num_warmup_steps=0,num_training_steps= config['epochs'] * len(train_dl))

    model,train_dl,valid_dl,optimizer,lr_scheduler = accelerator.prepare(model,train_dl,valid_dl,optimizer,lr_scheduler)

    best_loss = float('inf')
    start_time = time.time()
    for epoch in range(config["epochs"]):
        print(f"Epoch Started:{epoch}")
        best_loss = train_and_evaluate_loop(train_dl,valid_dl,model,optimizer,epoch,best_loss,lr_scheduler)

        end_time = time.time()
        print(f"{m_}Time taken by epoch {epoch} is {end_time-start_time:.2f}s{sr_}")
        start_time = end_time

    return best_loss



In [None]:


# Launch training; the returned best validation loss is shown as the cell output.
best_valid_loss = run()
best_valid_loss

