In [1]:
import torch 
from torch import nn
from torchvision import transforms,models
from torchvision.datasets import ImageFolder
from collections import namedtuple
from PIL import Image

In [2]:
# Compute device: the first GPU when CUDA is available, otherwise the CPU.
# Uses the same selection rule as Model.__init__, so nothing here assumes a GPU runtime.
dev = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

In [3]:
# Colab only: mount Google Drive so the /content/drive/... paths used later resolve.
import google.colab.drive as drive
drive.mount('/content/drive')

# Project-local modules (must be importable from the current working directory / sys.path):
# the image transformation network and the loss network.
from ResNet import ResNet
from loss_network import vgg

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [4]:
class Model(object):
    """Trainer for a feed-forward style-transfer network.

    ``transform_net`` (a ``ResNet``) maps content images to stylised images. It is
    optimised against losses computed on the activations of ``loss_net`` (a
    ``vgg``):
      * content loss: MSE between the ``relu2_2`` features of the output and of the input;
      * style loss: MSE between Gram matrices of every loss-network feature layer of the
        output and of the style image.
    ``loss_net`` is used only as a frozen feature extractor.
    """

    def __init__(self):
        # Prefer the GPU when one is present, otherwise fall back to the CPU.
        self.dev = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
        self.transform_net = ResNet().to(self.dev)
        self.loss_net = vgg().to(self.dev)
        self.loss_net.eval()
        # Freeze the loss network. Gradients still flow *through* it to the
        # transform network's output, but backward() no longer computes or
        # accumulates gradients for its own weights.
        for param in self.loss_net.parameters():
            param.requires_grad_(False)

    def load_dataset(self, dataset_path, style_img_path, image_size=256, batch_size=5):
        """Build the training DataLoader and load the style image.

        Args:
            dataset_path: root directory read with ``torchvision.datasets.ImageFolder``.
                The class labels it yields are discarded by ``train``.
            style_img_path: path to the style image (converted to RGB).
            image_size: side length of the square images fed to the networks.
            batch_size: number of content images per training step.

        Content and style images are both scaled to [0, 255], the range that
        ``vgg_normalize`` expects.
        """
        self.batch_size = batch_size
        t = transforms.Compose([
            transforms.Resize(image_size),
            transforms.CenterCrop(image_size),
            transforms.ToTensor(),
            transforms.Lambda(lambda x: x.mul(255))
        ])

        data_folder = ImageFolder(dataset_path, transform=t)
        self.loader = torch.utils.data.DataLoader(data_folder, batch_size=batch_size, shuffle=True)

        style_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Lambda(lambda x: x.mul(255))
        ])

        # LANCZOS is the filter that the old `Image.ANTIALIAS` alias referred to.
        # That alias was removed in Pillow 10.
        style = Image.open(style_img_path).convert('RGB').resize((image_size, image_size), Image.LANCZOS)
        # Kept as a single (C, H, W) image. train() broadcasts its statistics over batches.
        self.style_img = style_transform(style).to(self.dev)

    def vgg_normalize(self, x):
        """Scale a [0, 255] image tensor to [0, 1], then normalise each channel.

        Uses the ImageNet channel mean (0.485, 0.456, 0.406) and std
        (0.229, 0.224, 0.225). Channels are the third-from-last dimension, so the
        function accepts either (C, H, W) or (N, C, H, W).

        Returns a new tensor; ``x`` itself is never modified. This matters because
        callers pass tensors they keep using: the stored style image, the input
        batch, and the transform network's output.
        """
        scaled = x / 255.0
        mean = torch.tensor([0.485, 0.456, 0.406], dtype=scaled.dtype, device=scaled.device).view(-1, 1, 1)
        std = torch.tensor([0.229, 0.224, 0.225], dtype=scaled.dtype, device=scaled.device).view(-1, 1, 1)
        return (scaled - mean) / std

    def gram_matrix(self, x):
        """Size-normalised channel Gram matrix of a feature map.

        Args:
            x: tensor of shape (batch, ch, h, w).
        Returns:
            Tensor of shape (batch, ch, ch) equal to ``F @ F^T / (ch * h * w)``,
            where ``F`` is ``x`` flattened to (batch, ch, h * w).
        """
        (batch, ch, h, w) = x.size()
        features = x.view(batch, ch, h * w)
        return features.bmm(features.transpose(1, 2)) / (ch * h * w)

    def load_model(self, model_path):
        """Load transform-network weights from ``model_path``. Does nothing when it is None."""
        if model_path is not None:
            self.transform_net.load_state_dict(torch.load(model_path, map_location=self.dev))

    def train(self, epochs, style_weight, content_weight, lr, save_path=None, model_path=None, show_every=20, save_every=100):
        """Train the transform network with weighted content + style losses.

        ``load_dataset`` must be called first, because this method uses
        ``self.loader`` and ``self.style_img``. It can be called repeatedly, since
        ``self.style_img`` is left unchanged.

        Args:
            epochs: number of passes over ``self.loader``.
            style_weight: multiplier for the summed Gram-matrix MSE over all
                loss-network feature layers.
            content_weight: multiplier for the ``relu2_2`` feature MSE between the
                output and the input.
            lr: Adam learning rate.
            save_path: file that receives the transform network's ``state_dict``.
                Nothing is saved when it is None.
            model_path: optional checkpoint to resume from (see ``load_model``).
            show_every: print the losses every ``show_every`` batches.
            save_every: checkpoint every ``save_every`` batches. A checkpoint is
                also written at the end of each epoch.
        """
        self.mseloss = torch.nn.MSELoss()
        self.opt = torch.optim.Adam(self.transform_net.parameters(), lr)
        self.load_model(model_path)

        # The style targets are constant, so compute them once without an autograd graph.
        # One copy of the style image is enough: its Gram matrices are broadcast to each
        # batch's size below, which also covers a short final batch.
        with torch.no_grad():
            style_activations = self.loss_net(self.vgg_normalize(self.style_img.unsqueeze(0)))
            style_grams = [self.gram_matrix(feat) for feat in style_activations]

        for epoch in range(epochs):
            self.transform_net.train()
            print('EPOCH : ', epoch)

            for batch_id, (img, _) in enumerate(self.loader):
                img = img.to(self.dev)
                yhat = self.transform_net(img)

                yhat_activations = self.loss_net(self.vgg_normalize(yhat))
                # The content target only needs values, not gradients.
                with torch.no_grad():
                    content_activations = self.loss_net(self.vgg_normalize(img))

                content_loss = content_weight * self.mseloss(yhat_activations.relu2_2, content_activations.relu2_2)

                style_loss = 0
                for feat, style_gram in zip(yhat_activations, style_grams):
                    feat_gram = self.gram_matrix(feat)
                    style_loss += self.mseloss(feat_gram, style_gram.expand_as(feat_gram))
                style_loss *= style_weight

                total_loss = style_loss + content_loss

                self.opt.zero_grad()
                # No retain_graph needed: the targets carry no graph, so this step's
                # graph can be freed right after backward().
                total_loss.backward()
                self.opt.step()

                if batch_id % show_every == 0:
                    print(f'Batch : {batch_id} | Content loss : {content_loss.item()} | Style loss : {style_loss.item()} | Total loss : {total_loss.item()}')

                if save_path is not None and batch_id % save_every == 0:
                    torch.save(self.transform_net.state_dict(), save_path)

            # Also checkpoint at the end of each epoch, so progress made after the
            # last periodic save is not lost.
            if save_path is not None:
                torch.save(self.transform_net.state_dict(), save_path)


In [5]:
# --- Training configuration --------------------------------------------------
dataset_path = '/content/drive/MyDrive/EE/coco'              # ImageFolder root of the content images
style_img_path = '/content/drive/MyDrive/EE/styles/ooo.jpg'  # style image the network learns
EPOCHS = 10
CONTENT_WEIGHT = 1e5
STYLE_WEIGHT = 8e8
LR = 0.001
IMAGE_SIZE = 256   # side length of the square training images
BATCH_SIZE = 5
SHOW_EVERY = 10    # print losses every N batches
SAVE_EVERY = 100   # checkpoint every N batches
SEED = 0
SAVE_PATH = '/content/drive/MyDrive/EE/model2.pt'
MODEL_PATH = None  # set to a checkpoint path to resume training

# Seed before building the model so weight init and DataLoader shuffling are reproducible.
torch.manual_seed(SEED)

model = Model()
model.load_dataset(dataset_path=dataset_path, style_img_path=style_img_path, image_size=IMAGE_SIZE, batch_size=BATCH_SIZE)
model.train(model_path=MODEL_PATH, epochs=EPOCHS, style_weight=STYLE_WEIGHT, content_weight=CONTENT_WEIGHT, lr=LR, save_path=SAVE_PATH, show_every=SHOW_EVERY, save_every=SAVE_EVERY)

EPOCH :  0
Batch : 0 | Content loss : 603722.875 | Style loss : 3953227.75 | Total loss : 4556950.5
Batch : 10 | Content loss : 753697.9375 | Style loss : 3333485.5 | Total loss : 4087183.5
Batch : 20 | Content loss : 916206.375 | Style loss : 2678833.5 | Total loss : 3595040.0
Batch : 30 | Content loss : 1129492.0 | Style loss : 2193637.0 | Total loss : 3323129.0
Batch : 40 | Content loss : 1055943.25 | Style loss : 1845612.375 | Total loss : 2901555.5
Batch : 50 | Content loss : 1234030.875 | Style loss : 1455708.75 | Total loss : 2689739.5
Batch : 60 | Content loss : 1169994.125 | Style loss : 1316967.0 | Total loss : 2486961.0
Batch : 70 | Content loss : 1203522.75 | Style loss : 1135893.75 | Total loss : 2339416.5
Batch : 80 | Content loss : 1171420.5 | Style loss : 1007041.0625 | Total loss : 2178461.5
Batch : 90 | Content loss : 1121440.375 | Style loss : 971819.1875 | Total loss : 2093259.5
Batch : 100 | Content loss : 1268131.0 | Style loss : 836572.875 | Total loss : 2104704.

KeyboardInterrupt: ignored