In [1]:
import torch 
from torch import nn
from torchvision import transforms,models
from torchvision.datasets import ImageFolder
from PIL import Image
from ResNet import ResNet 
from loss_network import vgg
from disc import Discriminator

In [2]:
dev = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

In [3]:
from google.colab import drive 
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [4]:
class Model(object):

    def __init__(self, image_size=256):
        self.image_size = image_size
        self.dev = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
        self.transform_net = ResNet().to(self.dev)
        self.loss_net = vgg().to(self.dev)
        self.disc = Discriminator(self.image_size).to(self.dev)
        self.loss_net.eval()


    def load_dataset(self, dataset_path, style_img_path, batch_size=5):

        self.batch_size = batch_size
        t = transforms.Compose([
            transforms.Resize(self.image_size),
            transforms.CenterCrop(self.image_size),
            transforms.ToTensor(),
            transforms.Lambda(lambda x : x.mul(255))
        ])

        data_folder = ImageFolder(dataset_path, transform=t)
        self.contentloader = torch.utils.data.DataLoader(data_folder, batch_size=batch_size, shuffle=True)

        stylefolder = ImageFolder(style_img_path, transform=t)
        self.styleloader = torch.utils.data.DataLoader(stylefolder, batch_size=1, shuffle=True)


    def vgg_normalize(self, x):
        normalize = transforms.Normalize(mean=(0.485, 0.456, 0.406),
                                         std=(0.229, 0.224, 0.225))
        x = x.div_(255.0)
        out = normalize(x)
        return out


    def gram_matrix(self, x):
        (batch, ch, h, w) = x.size() 
        img = x.view(batch, ch, h*w)
        img_transposed = img.transpose(1,2)
        gram_matrix = img.bmm(img_transposed) / (ch * h * w)
        return gram_matrix


    def load_model(self, model_path):
        if model_path != None : 
            self.transform_net.load_state_dict(torch.load(model_path, map_location=self.dev))


    def train(self, epochs, style_weight, content_weight, disc_weight, lr, save_path=None, model_path=None, show_every=20, save_every=100):

        self.mseloss = torch.nn.MSELoss()
        self.bceloss = torch.nn.BCELoss()
        self.opt = torch.optim.Adam(self.transform_net.parameters(), lr)
        self.disc_opt = torch.optim.Adam(self.disc.parameters(), lr=0.0001)
        self.load_model(model_path)

        style_activation_list = []
        style_images_list = []

        for style_img, _ in self.styleloader : 
            style_images_list.append(style_img)
            style_img = style_img.repeat(self.batch_size, 1, 1, 1)
            style_activations = self.loss_net(self.vgg_normalize(style_img))
            style_grams = [self.gram_matrix(y) for y in style_activations]
            style_activation_list.append(style_grams)


        for epoch in range(epochs):
            self.transform_net.train()
            print('EPOCH : ', epoch)

            for batch_id, (img, _) in enumerate(self.contentloader):

                for style_grams, style_image in zip(style_activation_list, style_images_list) : 

                    n_batch = len(img)
                    
                    img = img.to(self.dev)
                    yhat = self.transform_net(img)

                    disc_real_output = disc.forward(style_image)
                    disc_fake_output = disc.forward(yhat)

                    disc_real_loss = self.bceloss(disc_real_output, torch.ones_like(disc_real_output))
                    disc_fake_loss = self.bceloss(disc_fake_output, torch.zeros_like(disc_fake_output))
                    disc_loss = (disc_fake_loss + disc_real_loss)/2

                    gen_real_loss = self.bceloss(disc_real_output, torch.zeros_like(disc_real_output))
                    gen_fake_loss = self.bceloss(disc_fake_output, torch.ones_like(disc_fake_output))
                    gen_loss = (gen_fake_loss + gen_real_loss)/2

                    gen_loss *= disc_weight

                    self.disc_opt.zero_grad()
                    disc_loss.backward()
                    self.disc_opt.step()

                    yhat_normalized = self.vgg_normalize(yhat)
                    content_normalized = self.vgg_normalize(img)

                    yhat_activations = self.loss_net(yhat_normalized)
                    content_activations = self.loss_net(content_normalized)

                    content_loss = content_weight * self.mseloss(yhat_activations[0], content_activations[0])

                    style_loss = 0
                    for yhat, style_gram in zip(yhat_activations, style_grams):
                        yhat_gram = self.gram_matrix(yhat)
                        style_loss += self.mseloss(yhat_gram, style_gram[:n_batch , : , :])
                    style_loss *= style_weight

                    total_loss = style_loss + content_loss + gen_loss

                    self.opt.zero_grad()
                    total_loss.backward(retain_graph=True)
                    self.opt.step()

                if batch_id % show_every == 0 : 
                    print(f'Batch : {batch_id} | Content loss : {content_loss.item()} | Style loss : {style_loss.item()} | Total loss : {total_loss.item()}')

                if batch_id % save_every == 0 : 
                    if save_path != None : 
                        torch.save(self.transform_net.state_dict(), save_path)


In [5]:
dataset_path = '/content/drive/MyDrive/EE/coco'
style_img_path = '/content/drive/MyDrive/EE/styles/muse.jpg'
EPOCHS = 10 
CONTENT_WEIGHT = 1e5
STYLE_WEIGHT = 1e9
LR = 0.001
SAVE_PATH = '/content/drive/MyDrive/EE/model.pt'
MODEL_PATH = None

model = Model()
model.load_dataset(dataset_path=dataset_path, style_img_path=style_img_path)
model.train(model_path=MODEL_PATH, epochs=EPOCHS, style_weight=STYLE_WEIGHT, content_weight=CONTENT_WEIGHT, lr=LR, save_path=SAVE_PATH, show_every=10)

EPOCH :  0
Batch : 0 | Content loss : 178605.140625 | Style loss : 14357277.0 | Total loss : 14535882.0
Batch : 10 | Content loss : 179770.75 | Style loss : 11208544.0 | Total loss : 11388315.0
Batch : 20 | Content loss : 169421.46875 | Style loss : 9406555.0 | Total loss : 9575976.0
Batch : 30 | Content loss : 215300.671875 | Style loss : 8521484.0 | Total loss : 8736785.0
Batch : 40 | Content loss : 221098.875 | Style loss : 7773263.5 | Total loss : 7994362.5
Batch : 50 | Content loss : 243177.3125 | Style loss : 7166873.0 | Total loss : 7410050.5
Batch : 60 | Content loss : 242279.796875 | Style loss : 6566985.5 | Total loss : 6809265.5
Batch : 70 | Content loss : 289894.90625 | Style loss : 5574581.5 | Total loss : 5864476.5
Batch : 80 | Content loss : 318673.8125 | Style loss : 4879651.0 | Total loss : 5198325.0
Batch : 90 | Content loss : 279572.84375 | Style loss : 4338081.0 | Total loss : 4617654.0
Batch : 100 | Content loss : 332152.4375 | Style loss : 3898925.75 | Total loss 

KeyboardInterrupt: ignored