In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import torch
import torchvision
import torchvision.models as visionmodels
import torchvision.transforms as transforms
import torch.utils.data.dataloader as DataLoader
import torch.utils.data.dataset as Dataset
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable,grad
from torchvision import datasets
from collections import *
import os
from torchvision.io import read_image
from datetime import datetime
import glob
from zipfile import ZipFile


In [None]:
from google.colab import drive
drive.mount('/content/gdrive')

# Permit TF32 math and switch on cuDNN benchmark mode.
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
torch.backends.cudnn.benchmark = True

# Global configuration dictionary read by the functions defined further down.
FLAGS = {
    # training data and optimisation settings
    'datadir': 'data/cats',
    'labeldir': '/content/gdrive/MyDrive/Colab Notebooks/data/cats_small.csv',
    'batch_size': 4,
    'learning_rate': .001,
    'num_epochs': 2000,
    'im_channels': 3,
    'noise_scale_factor': .01,

    # weights of the individual loss terms
    'style_weight': 1.,
    'content_weight': 1.,
    'tv_weight': 1e-4,
    'pixel_weight': 10.,

    # image sizes, reference images and output locations
    'im_size': 256,
    'val_image_path': '/content/gdrive/MyDrive/Colab Notebooks/Style Transfer/data/cats/val_cat.jpg',
    'style_image_path': '/content/gdrive/MyDrive/Colab Notebooks/Style Transfer/data/monet/07.jpg',
    'val_size_mult': 2,
    'output_path': 'outputs/Johnson_NST_IN',
    'model_path': 'models/Johnson_NST_IN',
    'output_fname': 'test_output',
    'model_fname': 'model',
    'home_dir': '/content/gdrive/MyDrive/Colab Notebooks/Style Transfer',
}


# Unpack the zipped training images into the (relative) data directory.
zip_loc = '/content/gdrive/MyDrive/Colab Notebooks/data/cats_small.zip'
with ZipFile(zip_loc, 'r') as archive:
  archive.extractall('data/cats')


  

# Output and Saving Methods

In [None]:
def plot_image(test_image, val_image, save_result: bool = True): 
  '''
    Draw the first image of two batches one above the other and optionally
    save the figure.

    test_image:  batch whose element 0 is drawn in the top panel
    val_image:   batch whose element 0 is drawn in the bottom panel
    save_result: when True the figure is written to
                 FLAGS['home_dir']/FLAGS['output_path'] as
                 '<output_fname>_<k>.png', where k is the number of PNG files
                 already in that directory

    Both arguments are indexed with [0], detached, moved to the CPU and
    permuted with (1,2,0) before being handed to imshow.

    The working directory is never changed: the target directory is created
    with os.makedirs and addressed through full paths.
  '''
  fig = plt.figure(figsize = (5,5), dpi = 200) 
  ax = fig.subplots(2)
  ax[0].imshow(test_image[0].detach().cpu().permute(1,2,0))
  ax[1].imshow(val_image[0].detach().cpu().permute(1,2,0))
  if save_result: 
    out_dir = os.path.join(FLAGS['home_dir'], FLAGS['output_path'])
    os.makedirs(out_dir, exist_ok = True)
    # glob.escape keeps special characters in the directory name literal
    num_imgs = len(glob.glob(os.path.join(glob.escape(out_dir), '*.png')))
    fname = FLAGS['output_fname'] + '_' + str(num_imgs) + '.png'
    fig.savefig(os.path.join(out_dir, fname))
  plt.show(block = False)
  # this function is called repeatedly during training; release the figure
  # so that open figures do not pile up
  plt.close(fig)
  
def save_model(model): 
  '''
    Save model.state_dict() under FLAGS['home_dir']/FLAGS['model_path'].

    The file is named '<model_fname>_<k>.model', where k is the number of
    '.model' files already present in the target directory, so earlier
    checkpoints are not overwritten.

    The target directory is created when missing and the working directory
    is left untouched.
  '''
  model_dir = os.path.join(FLAGS['home_dir'], FLAGS['model_path'])
  os.makedirs(model_dir, exist_ok = True)
  # glob.escape keeps special characters in the directory name literal
  num_models = len(glob.glob(os.path.join(glob.escape(model_dir), '*.model')))
  fname = FLAGS['model_fname'] + '_' + str(num_models) + '.model'
  torch.save(model.state_dict(), os.path.join(model_dir, fname))

# Models

In [None]:
'''
  single convolution layers
'''
class JohnsonConvLayer(nn.Module): 
  '''
    One convolution followed by InstanceNorm2d (affine) and LeakyReLU(0.2).

    upsample   -> nn.ConvTranspose2d with zero padding, requires stride > 1
    downsample -> nn.Conv2d, requires stride > 1
    neither    -> nn.Conv2d, requires stride == 1

    Setting both flags builds two convolutions and trips the final length
    assertion. The submodules live in self.conv_layer in the order
    convolution, normalisation, activation.
  '''
  def __init__(self, in_channels, out_channels,
               kernel: int = 3,
               stride: int = 1,
               padding: int = 1,
               output_padding: int = 0,
               padding_mode: str = 'reflect',
               bias: bool = True,
               upsample: bool = False,
               downsample: bool = False,
               relu_inplace: bool = True): 
    super().__init__() 
    # stride must agree with the kind of layer that was requested
    if upsample or downsample:
      assert stride > 1
    else:
      assert stride == 1

    stages = []
    if upsample:
      stages.append(nn.ConvTranspose2d(in_channels, out_channels, kernel_size = kernel,
                                       stride = stride, padding = padding,
                                       output_padding = output_padding, bias = bias,
                                       padding_mode = 'zeros'))
    if downsample or not upsample:
      stages.append(nn.Conv2d(in_channels, out_channels, kernel_size = kernel,
                              stride = stride, padding = padding,
                              padding_mode = padding_mode, bias = bias))
    stages.append(nn.InstanceNorm2d(out_channels, affine = True))
    stages.append(nn.LeakyReLU(.2, inplace = relu_inplace))
    # exactly one convolution is expected in front of norm + activation
    assert len(stages) == 3
    self.conv_layer = nn.Sequential(*stages)

  def forward(self,input): 
    '''apply convolution, normalisation and activation to input'''
    output = self.conv_layer(input)
    return output

'''
  residual layer
'''
class GrossResidualLayer(nn.Module): 
  '''
    Residual block: conv -> InstanceNorm2d -> LeakyReLU(0.2) -> conv ->
    InstanceNorm2d, whose result is added to the block input.

    Both convolutions share kernel, stride, padding, padding_mode and bias.
    NOTE(review): the addition in forward needs the branch to keep the
    spatial size of its input; the defaults (kernel 1, padding 1) do not,
    callers in this file pass kernel = 3.
  '''
  def __init__(self, in_channels, out_channels,
               kernel: int = 1,
               stride: int = 1,
               padding: int = 1,
               padding_mode: str = 'reflect',
               bias: bool = True): 
    super().__init__() 
    conv_options = dict(kernel_size = kernel, stride = stride, padding = padding,
                        padding_mode = padding_mode, bias = bias)
    branch = [
        nn.Conv2d(in_channels, out_channels, **conv_options),
        nn.InstanceNorm2d(out_channels, affine = True),
        nn.LeakyReLU(.2, inplace = True),
        nn.Conv2d(out_channels, out_channels, **conv_options),
        nn.InstanceNorm2d(out_channels, affine = True),
    ]
    self.conv_layer = nn.Sequential(*branch)

  def forward(self,input): 
    '''return input plus the output of the convolutional branch'''
    residual = self.conv_layer(input)
    return input + residual

'''
  main generator network
'''

class JohnsonGeneratorNetwork(nn.Module): 
  '''
    Image-to-image generator assembled from JohnsonConvLayer and
    GrossResidualLayer blocks:

      3 -> 32 (9x9) -> 64 (stride 2) -> 128 (stride 2)
      5 residual blocks at 128 channels
      128 -> 64 (transposed, stride 2) -> 32 (transposed, stride 2) -> 3 (9x9)

    All blocks are stored, in that order, in self.main.
  '''
  def __init__(self): 
    super().__init__() 
    encoder = [
        JohnsonConvLayer(3,32,kernel = 9, stride = 1, padding = 4),
        JohnsonConvLayer(32,64,kernel = 3, stride = 2, downsample = True),
        JohnsonConvLayer(64,128,kernel = 3, stride = 2, downsample = True),
    ]
    bottleneck = [GrossResidualLayer(128,128, kernel = 3, stride = 1) for _ in range(5)]
    decoder = [
        JohnsonConvLayer(128,64,kernel = 3, stride = 2, padding = 0, output_padding = 1, upsample = True),
        JohnsonConvLayer(64,32,kernel = 3, stride = 2, padding = 0, output_padding = 1, upsample = True),
        # the last activation is not in-place
        JohnsonConvLayer(32,3,kernel = 9, stride = 1, padding = 1,relu_inplace= False),
    ]
    self.main = nn.Sequential(*encoder, *bottleneck, *decoder)

  def forward(self,input): 
    '''run input through the whole block stack'''
    generated = self.main(input)
    return generated

class Normalization(nn.Module): 
  '''
    Per-channel input normalisation: (x - mean) / std.

    mean, std: 1-D tensors with one entry per channel; they are reshaped to
               (1,C,1,1) so that they broadcast over an (N,C,H,W) input.

    Both tensors are registered as non-persistent buffers. They therefore
    follow the module through .to(), .cuda() or .double() but do not show up
    in state_dict().
  '''
  def __init__(self,mean,std): 
    super().__init__() 
    self.register_buffer('mean', mean.view(1,-1,1,1), persistent = False)
    self.register_buffer('std', std.view(1,-1,1,1), persistent = False)

  def forward(self,x):  
    '''return the normalised copy of x'''
    return (x-self.mean)/self.std 

class FeatureNetwork(nn.Module): 
  '''
    Runs an image through a sequence of feature layers and collects the
    activations at selected layer indices.

    vgg_features: iterable of modules, applied one after the other.

    The layers are kept in an nn.ModuleList so that they are registered as
    submodules: parameters(), .to() and .eval() then reach them, which a
    plain Python list does not allow.

    The module-level name `device` must exist when an instance is created;
    the layers and the normalisation constants are moved there.
  '''
  def __init__(self,vgg_features): 
    super().__init__()
    self.features = nn.ModuleList([feature.to(device) for feature in vgg_features])
    #feature layers of a vgg19_bn network
    #self.content_layers = [32]
    #self.style_layers = [2,9,26,23,30,42]
    
    #layers of vgg19 net
    # self.content_layers = [22] 
    # self.style_layers = [1,6,11,20,29]

    #layers of vgg16
    self.content_layers = [8] 
    self.style_layers = [3,8,15,18,24]
    mean = torch.tensor([0.485, 0.456, 0.406]).to(device)
    std = torch.tensor([0.229, 0.224, 0.225]).to(device)
    self.norm = Normalization(mean,std)

  def forward(self,input): 
    '''
      returns (content, styles): two lists with the activations recorded at
      the indices in self.content_layers and self.style_layers, in layer
      order
    '''
    input = self.norm(input)
    content = []
    styles = [] 
    for i,module in enumerate(self.features): 
      input = module(input) 
      # the two tests are independent on purpose: an index listed in both
      # lists (8 here) has to end up in both outputs
      if i in self.content_layers: 
        content.append(input)
      if i in self.style_layers: 
        styles.append(input)
    return content, styles

# Losses

In [None]:
def gram_matrix(A): 
  '''
    Normalised Gram matrix of a batch of feature maps.

    A: tensor of shape (N,C,h,w)
    returns: tensor of shape (N,C,C), where entry [n,i,j] is the inner
             product of the flattened channels i and j of sample n, divided
             by C*h*w

    reshape is used instead of view so that non-contiguous inputs (for
    example transposed tensors) are accepted; a batch of one goes through
    the same batched product as any other batch size.
  '''
  N, C, h, w = A.shape
  flat = A.reshape(N, C, h*w)
  G = torch.bmm(flat, flat.transpose(1,2))
  return G.div(C*h*w)

def gramMSELoss(input,target): 
  '''
    mean squared error between the Gram matrix of input and target;
    target is expected to be a Gram matrix already
  '''
  return F.mse_loss(gram_matrix(input), target)

def gramFrobLoss(input,target): 
  '''
    Frobenius norm of (gram_matrix(input) - target), taken over the last two
    dimensions and summed over the remaining one
  '''
  gram_difference = gram_matrix(input) - target
  per_sample_norm = torch.linalg.norm(gram_difference, 'fro', (1,2))
  return per_sample_norm.sum()

def styleLoss(styles,target_styles,mode :str = 'fro'): 
  '''
    Sum of the per-layer style losses.

    styles:        feature maps of the generated image, one per style layer
    target_styles: matching target Gram matrices
    mode:          'fro' -> gramFrobLoss, 'mse' -> gramMSELoss

    Raises ValueError for any other mode; without that check an unknown
    mode would silently produce a loss of 0.
    Returns 0 when there are no layers to compare.
  '''
  if mode not in ('fro', 'mse'):
    raise ValueError("unknown style loss mode {!r}, expected 'fro' or 'mse'".format(mode))
  style_loss = 0 
  for style,target in zip(styles,target_styles): 
    if mode == 'fro': 
      style_loss += gramFrobLoss(style,target) 
    else: 
      style_loss += gramMSELoss(style,target) 
  return style_loss

'''
  compute normalized MSELoss
'''
def normMSELoss(input,target): 
  '''
    Sum over samples and channels of the Euclidean norm of (input - target)
    taken over the spatial positions, divided by c*h*w.

    input, target: tensors of shape (N,c,h,w)

    The spatial dimensions are flattened before the norm is taken, the same
    way pixelLoss does it. Asking torch.linalg.norm for order 2 over the
    two dimensions (2,3) would compute the matrix 2-norm (largest singular
    value) of every h x w map instead.
  '''
  _, c, h, w = input.shape
  difference = (input - target).flatten(2)
  return (1/(c*h*w))*torch.linalg.norm(difference,2,2).sum()

'''
  compute content loss 
'''
def contentLoss(contents, target_content): 
  '''
    add up normMSELoss over matching pairs of generated and target content
    features; the result is 0 when there are no pairs
  '''
  return sum(normMSELoss(content, target)
             for content, target in zip(contents, target_content))


'''
  pixel loss and total variation loss from johnson
'''
def pixelLoss(gen_image, target_image): 
  '''
    Sum over samples and channels of the Euclidean norm of the flattened
    difference between the two images, divided by n*c*h*w.

    Both arguments must have the same (n,c,h,w) shape.
  '''
  assert gen_image.shape == target_image.shape
  n,c,h,w = gen_image.shape
  scale = 1/(n*c*h*w)
  residual = gen_image.view(n,c,-1) - target_image.view(n,c,-1)
  channel_norms = torch.linalg.norm(residual,2,2)
  return scale*channel_norms.sum()

def totalVariationLoss(image, eps: float = 1e-8): 
  '''
    Total variation of an image batch, divided by n*c*h*w.

    image: tensor of shape (n,c,h,w)
    eps:   small constant added under the square root. The derivative of
           sqrt is unbounded at 0, so without it every pixel whose two
           neighbour differences are both zero (flat regions, e.g. after
           clipping) yields a NaN gradient.

    The image is zero-padded by one pixel on the right and at the bottom,
    so for the last row and column the difference is taken against 0.
  '''
  n, c, h, w = image.shape
  z = F.pad(image,(0,1,0,1))
  vertical = z[:,:,1:,:-1] - z[:,:,:-1,:-1]
  horizontal = z[:,:,:-1,1:] - z[:,:,:-1,:-1]
  tv_reg = (vertical.pow(2) + horizontal.pow(2) + eps).sqrt().sum()
  return tv_reg/(n*c*h*w) 



# DataLoader and weight initialization

In [None]:
class customDataset(Dataset.Dataset):
    '''
        Unlabelled image dataset.

        annotations_file: CSV file read with pandas; its second column holds
                          the image file names
        img_dir:          directory the file names are relative to
        transform:        optional callable applied to every loaded image
        target_transform: accepted but not used, there are no targets
    '''
    def __init__(self, annotations_file, img_dir, transform=None,target_transform = None):
        self.img_labels = pd.read_csv(annotations_file)
        self.img_dir = img_dir
        self.transform = transform

    def __len__(self):
        '''number of rows in the annotations file'''
        return len(self.img_labels)

    def __getitem__(self, idx):
        '''load image idx, divide it by 255 and apply the transform, if any'''
        file_name = self.img_labels.iloc[idx, 1]
        image = read_image(os.path.join(self.img_dir, file_name))/255
        if self.transform:
            return self.transform(image)
        return image

def _initialize_dataset(data_path = None, label_path = None, data_transform = None,target_transform = None):
    '''
        Build the image dataset and a shuffling DataLoader over it.

        The batch size is read from FLAGS['batch_size']; the loader uses two
        worker processes and pinned memory.
        Returns (loader, dataset).
    '''
    image_dataset = customDataset(label_path, data_path,
                                  transform=data_transform,
                                  target_transform = target_transform)
    loader_options = dict(batch_size=FLAGS['batch_size'], shuffle=True,
                          num_workers=2, pin_memory=True)
    batch_loader = torch.utils.data.DataLoader(image_dataset, **loader_options)
    return batch_loader, image_dataset

def init_weights(m):
  '''
    Weight initialiser meant for module.apply().

    Conv2d and ConvTranspose2d modules get Xavier-normal weights and a
    zeroed bias; every other module is left untouched. Layers created with
    bias = False have m.bias set to None and are handled without error.
  '''
  if isinstance(m,(nn.Conv2d, nn.ConvTranspose2d)):
    torch.nn.init.xavier_normal_(m.weight)
    # test the parameter itself: m.bias is None for bias-free layers, so
    # reaching for m.bias.data first would raise AttributeError
    if m.bias is not None: 
      m.bias.data.fill_(0.)

# Training Loop

In [None]:
def _add_input_noise(batch):
  '''
    return batch plus uniform noise scaled by FLAGS['noise_scale_factor'],
    clamped to [0,1]; the argument itself is not modified
  '''
  noise = torch.rand_like(batch)*FLAGS['noise_scale_factor']
  return (batch + noise).clamp(0,1)


def train(feature_net, 
          generator, 
          optimizer, 
          scheduler, 
          loader, 
          val_image, 
          style_image
          ):
  '''
    Train generator against a fixed style image.

    feature_net: network returning (content, styles) feature lists
    generator:   network being trained
    optimizer:   optimizer over the generator parameters
    scheduler:   learning-rate scheduler, stepped once per epoch
    loader:      iterable of image batches
    val_image:   image batch pushed through the generator for the periodic
                 preview
    style_image: image batch whose style features define the style target

    Reads the module-level names `device` and FLAGS. Batches whose size
    differs from FLAGS['batch_size'] end the epoch. Every tenth epoch the
    last loss is printed, a preview is plotted and the generator is saved.
    Returns None.

    The tensors passed in are never modified in place. Targets and the
    preview are computed under torch.no_grad(), so no autograd graph is
    built for them.
  '''
  for p in feature_net.parameters(): 
    p.requires_grad = False

  style_input = style_image.to(device)
  val_input = _add_input_noise(val_image.to(device))

  # style target: Gram matrices of the style image, repeated to batch size
  with torch.no_grad():
    _, style_features = feature_net(style_input)
  target_styles = [gram_matrix(A).tile(FLAGS['batch_size'],1,1) for A in style_features]

  for epoch in range(FLAGS['num_epochs']): 
    print('processing epoch {}/{}'.format(epoch+1, FLAGS['num_epochs']))
    # stay None when the epoch does not see a single full batch
    loss = None
    gen_output = None
    for image in loader: 
      if len(image) != FLAGS['batch_size']: 
        break

      batch = _add_input_noise(image.to(device))
      optimizer.zero_grad() 

      # forward pass; generated images are clamped to the input range
      gen_output = generator(batch).clamp(0,1)

      # features of the generated image (gradients flow back through these)
      contents,styles = feature_net(gen_output) 
      # content target: features of the noisy input, no gradient needed
      with torch.no_grad():
        target_content, _  = feature_net(batch)

      loss = 0 
      loss += styleLoss(styles, target_styles)*FLAGS['style_weight']
      loss += contentLoss(contents, target_content) *FLAGS['content_weight']
      loss += pixelLoss(gen_output, batch) *FLAGS['pixel_weight']
      loss += totalVariationLoss(gen_output) *FLAGS['tv_weight']
      loss.backward()

      optimizer.step() 
    scheduler.step()

    if (epoch+1) % 10 == 0: 
      if loss is None:
        print('no full batch was processed in this epoch, nothing to report')
      else:
        print('stats:{}'.format(loss.detach()))
        with torch.no_grad():
          # clamped like the training output so both panels share one range
          test_out = generator(val_input).clamp(0,1)
        plot_image(gen_output, test_out)
      save_model(generator)


In [None]:
if __name__ == '__main__': 
  os.chdir('/content/')
  %ls 
  device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
  vgg_net = visionmodels.vgg16(pretrained=True)
  vgg_net.eval()
  feature_net = FeatureNetwork(list(vgg_net.features)).to(device) 
  generator = JohnsonGeneratorNetwork().to(device)
  generator.apply(init_weights)
  optimizer = optim.Adam(generator.parameters(), lr = FLAGS['learning_rate'], betas = (.5,.999))
  #optimizer = optim.LBFGS(generator.parameters()) 
  def lmbda(epoch): 
    if epoch == 1000: 
      return .1 
    if epoch > 1000: 
      if epoch %200 == 0: 
        return .1
      else: 
        return 1 
    else: 
      return 1
  scheduler = optim.lr_scheduler.MultiplicativeLR(optimizer, lr_lambda = lmbda)

  im_size = FLAGS['im_size']
  test_mult = FLAGS['val_size_mult']

  transform = transforms.Compose([ transforms.Resize(im_size), transforms.CenterCrop((im_size,im_size))])
  val_transform = transforms.Compose([transforms.Resize(im_size*test_mult),transforms.CenterCrop((test_mult*im_size,test_mult*im_size))])

  train_loader,dataset = _initialize_dataset(data_path = FLAGS['datadir'], label_path = FLAGS['labeldir'],\
                                data_transform = transform)
  
  val_image = val_transform(read_image(FLAGS['val_image_path'])/255).unsqueeze(0)
  style_image = transform(read_image(FLAGS['style_image_path'])/255).unsqueeze(0)

  output = train(feature_net,       #network for extracting features
                  generator, 
                  optimizer, 
                  scheduler, 
                  train_loader,
                  val_image,                    #content target image
                  style_image                      #style target image
                  )