In [1]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import torch
import torchvision
import torchvision.models as visionmodels
import torchvision.transforms as transforms
import torch.utils.data.dataloader as DataLoader
import torch.utils.data.dataset as Dataset
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable,grad
from torchvision import datasets
from collections import *
import os
from torchvision.io import read_image
from datetime import datetime
import glob
from zipfile import ZipFile


In [9]:
# Colab-only setup: mount Google Drive so the zipped datasets and the
# output/model folders under MyDrive are reachable.
from google.colab import drive
drive.mount('/content/gdrive')

# Allow TF32 matmul/cuDNN kernels and let cuDNN benchmark convolution algorithms.
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
torch.backends.cudnn.benchmark = True

# Global configuration dictionary read by the helper functions and the training loop.
FLAGS = {} 
# relative folder the cats archive is extracted into (see the ZipFile call below)
FLAGS['datadir'] = 'data/cats'
# CSV whose columns hold the image file names (read by customDataset)
FLAGS['labeldir'] = '/content/gdrive/MyDrive/Colab Notebooks/data/cats_train.csv'
FLAGS['batch_size'] = 2
# NOTE(review): the optimizer in the main cell is built with a literal lr of 0.00005,
# so this value is not read anywhere in the visible notebook.
FLAGS['learning_rate'] = .0005
# Adam betas (passed to optim.Adam in the main cell)
FLAGS['betas'] = (.5,.99)
FLAGS['num_epochs'] = 2000
# NOTE(review): not referenced elsewhere in the visible notebook.
FLAGS['im_channels'] = 3

# multipliers applied to the style and content losses in train()
FLAGS['style_weight'] = 1.
FLAGS['content_weight'] = 1. 

# side length of the random training crop; validation crops are im_size * val_size_mult
FLAGS['im_size'] = 256
FLAGS['val_image_path'] = '/content/gdrive/MyDrive/Colab Notebooks/Style Transfer/data/cats/val_cat.jpg' 
FLAGS['style_image_path'] = '/content/gdrive/MyDrive/Colab Notebooks/Style Transfer/data/monet/07.jpg' 
FLAGS['val_size_mult'] = 2
# output_path and model_path are resolved relative to home_dir by plot_image / save_model
FLAGS['output_path'] = 'outputs/ADAIN'
FLAGS['model_path'] = 'models/ADAIN'
FLAGS['output_fname'] = 'test_output'
FLAGS['model_fname'] = 'model'
FLAGS['home_dir'] = '/content/gdrive/MyDrive/Colab Notebooks/Style Transfer'


# Unpack both image archives from Drive into the local working directory.
# This runs every time the cell is executed.
zip_loc = '/content/gdrive/MyDrive/Colab Notebooks/data/cats.zip'
with ZipFile(zip_loc, 'r') as zf: 
  zf.extractall('data/cats')
zip_loc_monet = '/content/gdrive/MyDrive/Colab Notebooks/data/monet_jpg.zip'
with ZipFile(zip_loc_monet, 'r') as zf: 
  zf.extractall('data/monet') 


Drive already mounted at /content/gdrive; to attempt to forcibly remount, call drive.mount("/content/gdrive", force_remount=True).


#Output and Saving Methods

In [3]:
def plot_image(test_image, val_image, save_result: bool = True): 
  '''
    Show the first image of two batches stacked vertically and optionally save the figure.

    test_image, val_image: batched image tensors; element [0] of each is detached,
      moved to the CPU and permuted with (1,2,0) before imshow
      (assumes N x C x H x W input -- TODO confirm against callers)
    save_result: when True the figure is written to
      FLAGS['home_dir']/FLAGS['output_path'] as '<output_fname>_<k>.png', where k is
      the number of .png files already present in that folder

    The output folder is created if missing. The file is addressed by its full path,
    so the process working directory is never changed (the relative data paths used
    by the loaders keep working even if saving fails). The figure is shown in both
    modes and then closed so figures do not pile up over many epochs.
  '''
  fig = plt.figure(figsize = (5,5), dpi = 200) 
  top_ax, bottom_ax = fig.subplots(2)
  top_ax.imshow(test_image[0].detach().cpu().permute(1,2,0))
  bottom_ax.imshow(val_image[0].detach().cpu().permute(1,2,0))
  if save_result: 
    out_dir = os.path.join(FLAGS['home_dir'], FLAGS['output_path'])
    os.makedirs(out_dir, exist_ok = True)
    # escape the folder part so only '*.png' is treated as a pattern
    num_imgs = len(glob.glob(os.path.join(glob.escape(out_dir), '*.png')))
    fname = FLAGS['output_fname'] + '_' + str(num_imgs) + '.png'
    fig.savefig(os.path.join(out_dir, fname))
  plt.show(block = False)
  # release the figure; this function is called once per epoch
  plt.close(fig)
  
def save_model(model): 
  '''
    Save model.state_dict() under FLAGS['home_dir']/FLAGS['model_path'].

    The file is named '<model_fname>_<k>.model', where k is the number of .model
    files already in that folder, so earlier checkpoints are not overwritten.
    The folder is created if missing.

    The original looked up the misspelled key 'model_pathv'; the resulting KeyError
    was swallowed by a bare except. Here the path is built explicitly and the
    working directory is left untouched.
  '''
  model_dir = os.path.join(FLAGS['home_dir'], FLAGS['model_path'])
  os.makedirs(model_dir, exist_ok = True)
  # escape the folder part so only '*.model' is treated as a pattern
  num_models = len(glob.glob(os.path.join(glob.escape(model_dir), '*.model')))
  fname = FLAGS['model_fname'] + '_' + str(num_models) + '.model'
  torch.save(model.state_dict(), os.path.join(model_dir, fname))

#Models

In [13]:
class Normalization(nn.Module): 
  '''
    Per-channel input normalization: (x - mean) / std.

    mean, std: 1-D tensors with one entry per channel; both are reshaped to
      1 x C x 1 x 1 so they broadcast over N x C x H x W inputs.

    The statistics are registered as non-persistent buffers: they follow the module
    through .to()/.cuda()/.double() but do not appear in state_dict().
  '''
  def __init__(self,mean,std): 
    super().__init__() 
    self.register_buffer('mean', mean.view(1,-1,1,1), persistent = False)
    self.register_buffer('std', std.view(1,-1,1,1), persistent = False)
  def forward(self,x):  
    return (x-self.mean)/self.std 

'''
  doesn't use feature-matched content
'''
class FeatureNetwork(nn.Module): 
  '''
    Fixed feature extractor built from the leading layers of a VGG "features" stack.

    vgg_features: iterable of layers; only indices 0..max(style_layers) are kept and
      each kept layer is moved to the module-level `device`.

    forward returns the activations recorded after every index in style_layers,
    ordered from shallow to deep.

    The layers are stored in an nn.ModuleList so that they are registered as
    sub-modules: parameters() then yields the VGG weights (which lets callers freeze
    them) and .to()/.eval() reach them. A plain Python list hides them from nn.Module.
  '''
  def __init__(self,vgg_features): 
    super().__init__()
    
    #feature layers of a vgg19_bn network
    #self.content_layers = [32]
    #self.style_layers = [2,9,26,23,30,42]
    
    #layers of vgg19 net
    #NOTE(review): the main cell passes vgg16 features; confirm these indices are the intended ones
    #self.content_layers = [22] 
    self.style_layers = [1,6,11,20]
    max_layer = self.style_layers[-1]
    #layers of vgg16
    # self.content_layers = [8] 
    # self.style_layers = [3,8,15,18,24]

    self.features = nn.ModuleList(
        [feature.to(device) for i,feature in enumerate(vgg_features) if i <= max_layer])

    #channel statistics applied to the input before the first layer
    mean = torch.tensor([0.485, 0.456, 0.406]).to(device)
    std = torch.tensor([0.229, 0.224, 0.225]).to(device)
    self.norm = Normalization(mean,std)
  def forward(self,input): 
    input = self.norm(input)
    styles = [] 
    for i,module in enumerate(self.features): 
      input = module(input) 
      if i in self.style_layers: 
        styles.append(input)
    return styles



'''
  AdaIN layer
  Takes in a layer's features and aligns a target image's styles to a style image's mean and variance
  need to run this for each style
'''
class AdaIN(nn.Module): 
  '''
    Adaptive instance normalization of one feature map.

    Each (sample, channel) plane of the source is shifted/scaled to zero mean and unit
    standard deviation, then re-scaled with the supplied target deviation and mean.
  '''
  def __init__(self): 
    super().__init__()  
    
  #source_style --> N x C x K x K feature map
  #target_mean, target_dev --> tensors holding N*C values (viewed as N x C x 1 x 1)
  def forward(self, source_style, target_mean, target_dev): 
    n,c,_,_ = source_style.shape
    mu_x = self.compute_mean(source_style)
    sigma_x = self.compute_var(source_style, mu_x)
    normalized = (source_style - mu_x)/sigma_x
    return target_dev.view(n,c,1,1)*normalized + target_mean.view(n,c,1,1) 
  
  #compute the instance mean of a feature tensor
  #returns an N x C x 1 x 1 tensor
  def compute_mean(self, style):
    n,c,h,w = style.shape 
    return style.mean(dim = (2,3)).view(n,c,1,1)
  
  #compute the instance standard deviation sqrt(E[(x-mean)^2] + eps)
  #the original summed the raw activations and ignored `mean`, which returns
  #sqrt(mean) instead of a deviation
  #mean may be N x C or N x C x 1 x 1; it is matched to style's device and dtype
  #returns an N x C x 1 x 1 tensor
  def compute_var(self, style, mean, eps = 1e-8): 
    n,c,h,w = style.shape 
    centered = style - mean.to(style).view(n,c,1,1)
    variance = centered.pow(2).sum(dim = (2,3)).div(h*w) + eps
    return variance.sqrt().view(n,c,1,1)

'''
  applies adain to each layer to normalize
  this output list now has transformed features and means and stds of original output
'''
class AdaModule(nn.Module): 
  '''
    Runs AdaIN over a list of feature maps.

    styles, target_means and target_devs are zipped together, so the result has one
    modulated feature map per (style, mean, dev) triple, in the same order.
  '''
  def __init__(self): 
    super().__init__() 
    self.adain = AdaIN()
  def forward(self, styles, target_means, target_devs): 
    layer_triples = zip(styles, target_means, target_devs)
    return [self.adain(feat, mu, sigma) for feat, mu, sigma in layer_triples]
'''
  decoder, mimics vgg network layout, inverted, with no normalization
'''
class VggConvBlock(nn.Module): 
  '''
    Stack of num_layers (Conv2d + LeakyReLU(.2)) pairs.

    The first num_layers-1 pairs keep in_channels; the last pair maps in_channels to
    out_channels. Every pair owns its own Conv2d. The original appended one shared
    conv_layer instance num_layers-1 times, so the repeated layers shared a single set
    of weights.

    NOTE: parameters now live only under conv_block.*; the former conv_layer.* /
    final_layer.* state_dict entries no longer exist, so checkpoints written by the
    weight-sharing version will not load with strict=True.

    forward(input, style=None): when style is given it is concatenated to input along
    the channel dimension first, so in_channels must equal the combined channel count.
  '''
  def __init__(self, in_channels, out_channels,num_layers = 4, kernel = 3, stride =1 , padding = 1, padding_mode = 'reflect', bias = True): 
    super().__init__() 

    def conv_relu(c_in, c_out):
      #one independent convolution followed by its activation
      return nn.Sequential(
          nn.Conv2d(c_in, c_out, kernel_size = kernel, stride = stride, padding = padding, padding_mode = padding_mode, bias = bias),
          nn.LeakyReLU(.2, inplace = False)
      )

    layers = [conv_relu(in_channels, in_channels) for _ in range(num_layers-1)]
    layers.append(conv_relu(in_channels, out_channels))
    self.conv_block = nn.Sequential(*layers)

  def forward(self,input,style = None):
    if style is not None: 
      input = torch.cat((input,style), dim = 1)  
    return self.conv_block(input)

class RealTimeDecoder(nn.Module): 
  '''
    Decoder made of four VggConvBlocks with a shared x2 nearest-neighbour upsample
    between consecutive blocks. All sub-modules live in the ModuleDict `module`.
  '''
  def __init__(self): 
    super().__init__() 
    self.module = nn.ModuleDict()
    #(key, in_channels, out_channels, num_layers) for each stage, deepest first
    stage_specs = (
        ('convblock1', 512,   256, 4),
        ('convblock2', 256*2, 128, 4),
        ('convblock3', 128*2, 64,  2),
        ('convblock4', 64*2,  3,   2),
    )
    for key, c_in, c_out, depth in stage_specs:
      self.module[key] = VggConvBlock(c_in, c_out, num_layers = depth)
    self.module['upsample'] = nn.Upsample(scale_factor=2, mode = 'nearest')

  '''
    input is a list of styles of length 4
    input[0] feeds the first block; input[1..3] are passed as the extra `style`
    argument of blocks 2..4 after each upsample
  '''
  def forward(self,input): 
    out = self.module['convblock1'](input[0])
    for stage in (1, 2, 3):
      out = self.module['upsample'](out)
      out = self.module['convblock' + str(stage + 1)](out, input[stage])
    return out 
    

#Losses

In [14]:
'''
  Style Losses defined here do not rely on gram matrices but rather losses on statistics
'''
class StyleLoss(nn.Module): 
  '''
    Statistics-matching style loss.

    For every feature map the per-channel mean and standard deviation are compared
    with the supplied targets; the two L2 distances (taken over the channel dimension)
    are added up across layers and the result is averaged over the batch.
  '''
  def __init__(self): 
    super().__init__() 
    
  #input --> list of N x C x H x W feature maps
  #means, stds --> lists of N x C target statistics, one entry per feature map
  def forward(self,input, means, stds): 
    loss = 0
    for s, mean, std in zip(input, means,stds): 
      mu = self.compute_mean(s)          #computed once and reused for the deviation
      sigma = self.compute_var(s, mu)
      loss = loss + torch.linalg.norm(mu-mean, ord = 2, dim = 1) + torch.linalg.norm(sigma-std, ord = 2, dim = 1)
    return loss.mean()

  #per-sample, per-channel mean --> N x C
  def compute_mean(self, style):
    n,c,h,w = style.shape 
    return style.mean(dim = (2,3))
  
  #per-sample, per-channel standard deviation sqrt(E[(x-mean)^2] + eps) --> N x C
  #the original summed the raw activations and ignored `mean`, which returns
  #sqrt(mean) instead of a deviation
  #mean is matched to style's device and dtype, so a CPU copy of the mean is accepted
  def compute_var(self, style, mean, eps = 1e-8): 
    n,c,h,w = style.shape 
    centered = style - mean.to(style).view(n,c,1,1)
    variance = centered.pow(2).sum(dim =(2,3)).div(h*w) + eps
    return variance.sqrt() 


'''
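  compute a normalized L2 loss: per-channel L2 norms of the difference, summed and scaled by 1/(c*h*w)
  (note: the norm is not squared, so this is not a true mean squared error)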
'''
def normMSELoss(input,target): 
  '''
    Sum of per-(sample, channel) L2 norms of input - target, scaled by 1/(c*h*w).

    input, target: 4-D tensors; both are viewed as n x c x (h*w) using input's n and c.
    Returns a 0-dim tensor. The norm is not squared and the batch size is not part of
    the scale factor.
  '''
  n, c, h, w = input.shape
  flat_diff = input.view(n,c,-1) - target.view(n,c,-1)
  channel_norms = torch.linalg.norm(flat_diff, ord = 2, dim = 2)
  scale = 1/(c*h*w)
  return scale*channel_norms.sum()

'''
  compute content loss between modulated features against generated image features
'''
def contentLoss(contents, target_content): 
  '''
    Adds up normMSELoss over paired feature maps.

    contents and target_content are zipped, so the shorter list decides how many
    pairs are used; with no pairs the integer 0 is returned.
  '''
  pair_losses = (normMSELoss(feat, ref) for feat, ref in zip(contents, target_content))
  return sum(pair_losses, 0)

'''
  pixel loss and total variation loss from johnson
'''
def pixelLoss(gen_image, target_image): 
  '''
    Sum of per-(sample, channel) L2 norms of gen_image - target_image,
    scaled by 1/(n*c*h*w).

    Both tensors must have the same 4-D shape (checked with an assert).
    Returns a 0-dim tensor.
  '''
  assert gen_image.shape == target_image.shape
  n,c,h,w = gen_image.shape
  flat_diff = gen_image.view(n,c,-1) - target_image.view(n,c,-1)
  channel_norms = torch.linalg.norm(flat_diff, ord = 2, dim = 2)
  scale = 1/(n*c*h*w)
  return scale*channel_norms.sum()


#Data Loader and initialization 

In [15]:
class customDataset(Dataset.Dataset):
    '''
        Unlabelled image dataset driven by a CSV of file names.

        annotations_file: CSV path, loaded with pandas
        img_dir: folder the file names are joined onto
        transform: optional callable applied to each loaded image
        target_transform: accepted for signature compatibility but neither stored
            nor applied (there are no targets)
        index: column position in the CSV that holds the image file name
    '''
    def __init__(self, annotations_file, img_dir, transform=None,target_transform = None,index = 2):
        self.img_labels = pd.read_csv(annotations_file)
        self.ind = index
        self.transform = transform
        self.img_dir = img_dir

    def __len__(self):
        # one item per CSV row
        return self.img_labels.shape[0]

    def __getitem__(self, idx):
        file_name = self.img_labels.iloc[idx, self.ind]
        # read_image output is divided by 255 before any transform is applied
        image = read_image(os.path.join(self.img_dir, file_name))/255
        return self.transform(image) if self.transform else image

def _initialize_dataset(data_path = None, label_path = None, data_transform = None,target_transform = None,index = 2):
    '''
        Build a customDataset and a shuffled DataLoader over it.

        The loader uses FLAGS['batch_size'], two worker processes and pinned memory.
        Returns (loader, dataset).
    '''
    dataset = customDataset(
        label_path,
        data_path,
        transform = data_transform,
        target_transform = target_transform,
        index = index,
    )
    loader_options = dict(batch_size = FLAGS['batch_size'], shuffle = True, num_workers = 2, pin_memory = True)
    training_set = torch.utils.data.DataLoader(dataset, **loader_options)
    return training_set, dataset

def init_weights(m):
  '''
    Weight initializer meant for nn.Module.apply.

    Conv2d and ConvTranspose2d layers get Xavier-normal weights and, when they have
    a bias, a zero bias. Every other module type is left untouched.

    The bias check is made on m.bias itself: for layers built with bias=False,
    m.bias is None and reading m.bias.data would raise AttributeError.
  '''
  if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
    torch.nn.init.xavier_normal_(m.weight)
    if m.bias is not None: 
      m.bias.data.fill_(0.)

#Training method 

In [16]:
def train(feature_net, 
          generator, 
          ada_module,
          optimizer, 
          scheduler, 
          loader, 
          styleLoader,
          val_image, 
          style_image
          ):
  '''
    Train the decoder (`generator`) against a fixed feature network.

    feature_net: feature extractor returning a list of feature maps (shallow to deep)
    generator:   decoder fed with the AdaIN-modulated features, deepest first
    ada_module:  applies AdaIN to every feature map
    optimizer, scheduler: optimizer for the generator, stepped per batch / per epoch
    loader:      content image batches
    styleLoader: style image batches; statistics of up to 50 full batches are cached
    val_image:   1 x C x H x W content image used for the per-epoch preview
    style_image: kept for interface compatibility; currently unused because the
                 target statistics come from styleLoader

    Reads the module-level `device` and `FLAGS`. Returns None. After every epoch a
    preview is plotted (plot_image) and the generator is saved (save_model).

    Changes with respect to the previous version:
      * the validation features are computed from val_image (style_image was used
        by mistake and val_image was never read)
      * random statistics are drawn from the batches actually collected, so a short
        style loader no longer raises IndexError and the last batch can be picked
      * style batches smaller than FLAGS['batch_size'] are skipped, because their
        statistics cannot be reshaped to the training batch size
      * tensors go to `device` rather than a hard-coded .cuda()
      * statistics and the validation pass run under no_grad
      * both losses are back-propagated in a single backward call
  '''
  for p in feature_net.parameters(): 
    p.requires_grad = False
  for p in generator.parameters(): 
    p.requires_grad = True

  styleloss = StyleLoss() 

  '''
    create set of targetstyles
  '''
  num_target_styles_batches = 50
  means, devs = _collect_style_stats(feature_net, styleloss, styleLoader, num_target_styles_batches)
  if not means: 
    raise ValueError('styleLoader produced no full batches to take target statistics from')

  '''
    fixed validation input: content features of val_image modulated with the
    statistics of the first sample of one random style batch
  '''
  r = np.random.randint(0, len(means))
  test_means = [m[0].to(device) for m in means[r]]
  test_devs = [d[0].to(device) for d in devs[r]]
  with torch.no_grad(): 
    val_styles = feature_net(val_image.to(device))
    val_mod_styles = ada_module(val_styles, test_means, test_devs)[::-1]

  for epoch in range(FLAGS['num_epochs']): 
    print('processing epoch {}/{}'.format(epoch+1, FLAGS['num_epochs']))
    generator.train()
    generated_output = None
    
    for i,image in enumerate(loader): 
      #a short (last) batch cannot be paired with the cached statistics
      if len(image) != FLAGS['batch_size']: 
        break
      #generate random target styles
      r = np.random.randint(0, len(means)) 
      target_means = [m.to(device) for m in means[r]] 
      target_devs = [d.to(device) for d in devs[r]]
     
      input = image.to(device)  
      
      '''
          forward pass
          take images and pass through encoder, no gradient for this part
      '''
      with torch.no_grad(): 
        input_styles = feature_net(input)
        '''
          use input styles and align to target styles 
        '''
        modulated_styles = ada_module(input_styles, target_means, target_devs)[::-1]
      
      '''
        pass through the decoder network
      '''
      optimizer.zero_grad() 
      generated_output = generator(modulated_styles)
      '''
        compute styles after generation
      '''
      styles = feature_net(generated_output)
      #loss against the style image statistics
      style_loss = styleloss(styles, target_means, target_devs)*FLAGS['style_weight']
      #loss against the feature maps
      content_loss = contentLoss(styles[::-1], modulated_styles) *FLAGS['content_weight']
      #one backward pass over the summed loss gives the same gradients as two
      (style_loss + content_loss).backward()
      
      optimizer.step() 
      if (i+1)%10 == 0: 
        print('current style loss: {}, content loss: {}'.format(style_loss.item(),content_loss.item()))

    scheduler.step()
  
    generator.eval()
    with torch.no_grad(): 
      test_out = generator(val_mod_styles)
    #generated_output stays None when the loader produced no full batch
    if generated_output is not None: 
      plot_image(generated_output, test_out)

    #a checkpoint is written after every epoch
    save_model(generator)


def _collect_style_stats(feature_net, styleloss, styleLoader, max_batches): 
  '''
    Cache per-layer channel means and deviations for up to max_batches style batches.

    Returns (means, devs): two lists with one entry per kept batch; each entry is a
    list of CPU tensors, one per feature layer. Batches whose size differs from
    FLAGS['batch_size'] are skipped.
  '''
  means = [] 
  devs = [] 
  with torch.no_grad(): 
    for style_batch in styleLoader: 
      if len(means) >= max_batches: 
        break
      if len(style_batch) != FLAGS['batch_size']: 
        continue
      batch_means = [] 
      batch_devs = [] 
      for s in feature_net(style_batch.to(device)): 
        mu = styleloss.compute_mean(s)
        batch_means.append(mu.detach().cpu())
        batch_devs.append(styleloss.compute_var(s, mu).detach().cpu())
      means.append(batch_means)
      devs.append(batch_devs)
  return means, devs


In [None]:
# Entry cell: builds the feature network, decoder, optimizer, scheduler and data
# loaders, then starts training.
if __name__ == '__main__': 
  # relative paths such as 'data/cats' are resolved from /content/
  os.chdir('/content/')
  %ls 
  # `device` is a module-level global; FeatureNetwork.__init__ and train() read it
  device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
  # NOTE(review): vgg16 is loaded here, while the active style_layers in FeatureNetwork
  # sit under a "layers of vgg19 net" comment -- confirm which network is intended.
  vgg_net = visionmodels.vgg16(pretrained=True)
  vgg_net.eval()
  feature_net = FeatureNetwork(list(vgg_net.features)).to(device) 
  ada_module = AdaModule().to(device)
  generator = RealTimeDecoder().to(device)
  generator.apply(init_weights)
  # NOTE(review): lr is the literal 0.00005; FLAGS['learning_rate'] (.0005) is not used.
  optimizer = optim.Adam(generator.parameters(), lr = 0.00005, betas = FLAGS['betas'])
  #optimizer = optim.LBFGS(generator.parameters()) 
  # Factor handed to MultiplicativeLR: .1 at epoch 1000 and at every later epoch
  # divisible by 200, otherwise 1.
  def lmbda(epoch): 
    if epoch == 1000: 
      return .1 
    if epoch > 1000: 
      if epoch %200 == 0: 
        return .1
      else: 
        return 1 
    else: 
      return 1
  scheduler = optim.lr_scheduler.MultiplicativeLR(optimizer, lr_lambda = lmbda)

  im_size = FLAGS['im_size']
  test_mult = FLAGS['val_size_mult']

  # training: resize to 2*im_size, then take a random im_size x im_size crop
  transform = transforms.Compose([ transforms.Resize(im_size*2), transforms.RandomCrop((im_size,im_size))])
  # validation: same recipe scaled by val_size_mult
  val_transform = transforms.Compose([transforms.Resize(im_size*test_mult*2),transforms.RandomCrop((test_mult*im_size,test_mult*im_size))])

  train_loader,dataset = _initialize_dataset(data_path = FLAGS['datadir'], label_path = FLAGS['labeldir'],\
                                data_transform = transform)
  
  # NOTE(review): the monet loader reads the same CSV as the cats loader (FLAGS['labeldir'])
  # but takes file names from column 3 -- presumably that column lists the monet files; confirm.
  # `dataset` is rebound here, so the cats dataset object is no longer referenced by name.
  monet_loader,dataset = _initialize_dataset(data_path = 'data/monet', label_path = FLAGS['labeldir'],\
                                data_transform = transform,index = 3)
  
  # single images: scaled by 1/255, transformed, and given a batch dimension
  val_image = val_transform(read_image(FLAGS['val_image_path'])/255).unsqueeze(0)
  style_image = transform(read_image(FLAGS['style_image_path'])/255).unsqueeze(0)

  # train() has no return statement, so `output` is None
  output = train(feature_net,       #network for extracting features
                  generator, 
                  ada_module, 
                  optimizer, 
                  scheduler, 
                  train_loader,
                  monet_loader, 
                  val_image,                    #content target image
                  style_image                      #style target image
                  )

In [None]:
%cd 
%ls

In [None]:
%cd /content/data
%ls

In [None]:
%cd cats 
%ls