In [1]:
import torch
import torch.nn as nn
from sketch_loader import Sketch_Data
from models import AE, VAE
from torch.utils.data import DataLoader


trainset = Sketch_Data("Sketch_Anime/train")
validset = Sketch_Data("Sketch_Anime/val")

batchsize = 32
epochs = 180
lr = 1e-3
train_only = False

train_loader = DataLoader(trainset, batch_size = batchsize, shuffle = True)
val_loader = DataLoader(validset, batch_size = batchsize, shuffle = True)

device = "cuda:0"
model = AE().to(device)

loss_function = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=lr)

KeyboardInterrupt: 

In [None]:
from PIL import Image
from torchvision.transforms import ToPILImage
from tqdm import tqdm
import os
from utils import save_example

# Ensure the directory exists for saving
output_dir = 'output_images'
if not os.path.exists(output_dir):
    os.makedirs(output_dir)


for epoch in range(epochs):
    model.train()
    bi = 0
    train_loss = 0
    for x, y in tqdm(train_loader):
        bi+=1
        x, y = x.to(device), y.to(device).float()
        optimizer.zero_grad()
        output = model(x).float()
        loss = loss_function(output, y)
        loss.backward()
        train_loss += loss
        optimizer.step()
        # if bi == 200:
        #     save_example(output, x, y, 0)

    model.eval()
    test_loss = 0
    if not train_only:
        with torch.no_grad():
            for batch_idx, (x, y) in enumerate(tqdm(val_loader)):
                x, y = x.to(device), y.to(device).float()
                output = model(x).float()
                loss = loss_function(output, y)
                test_loss += loss
                
                # Save the first batch's images as examples
                if batch_idx == 0:
                    save_example(epoch, output, x, y, 0, output_dir)

    print(f"Epoch {epoch+1}: train loss = {train_loss}, test loss = {test_loss}")
    torch.save(model, 'last.pt')


In [1]:
from PIL import Image
from torchvision.transforms import ToPILImage
from tqdm import tqdm
import os
from utils import save_example
import torch
import torch.nn as nn
from sketch_loader import Sketch_Data
from models import AE, VAE
from torch.utils.data import DataLoader


trainset = Sketch_Data("Sketch_Anime/train")
validset = Sketch_Data("Sketch_Anime/val")

batchsize = 2
epochs = 180
lr = 1e-5
train_only = False

train_loader = DataLoader(trainset, batch_size = batchsize, shuffle = True)
val_loader = DataLoader(validset, batch_size = batchsize, shuffle = True)

device = "cuda:0"
model = VAE().to(device)

loss_function = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=lr)

output_dir = 'output_images'
if not os.path.exists(output_dir):
    os.makedirs(output_dir)

def kl_criterion(mu, logvar, batch_size):
    KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    KLD /= batch_size
    return KLD

beta = 0
max_beta = 0.1
step = 0.01

for epoch in range(epochs):
    model.train()
    train_loss = 0
    for x, y in tqdm(train_loader):
        x, y = x.to(device), y.to(device).float()
        optimizer.zero_grad()
        output, mu, logvar = model(x, y)
        loss = loss_function(output.float(), y) + beta * kl_criterion(mu, logvar, batchsize)
        loss.backward()
        train_loss += loss
        optimizer.step()

    model.eval()
    test_loss = 0
    if not train_only:
        with torch.no_grad():
            for batch_idx, (x, y) in enumerate(tqdm(val_loader)):
                x, y = x.to(device), y.to(device).float()
                output, _, _ = model(x, y, False)
                loss = loss_function(output.float(), y)
                test_loss += loss
                
                # Save the first batch's images as examples
                if batch_idx == 0:
                    save_example(epoch, output, x, y, 0, output_dir)

    if beta < max_beta:
        beta+= step
    print(f"Epoch {epoch+1}: train loss = {train_loss}, test loss = {test_loss}")
    torch.save(model, 'VAE.pt')


  from .autonotebook import tqdm as notebook_tqdm
100%|██████████| 7112/7112 [1:02:53<00:00,  1.88it/s]
100%|██████████| 1773/1773 [03:29<00:00,  8.45it/s]


Epoch 1: train loss = 333.9888000488281, test loss = 163.2249755859375


100%|██████████| 7112/7112 [1:01:09<00:00,  1.94it/s]
100%|██████████| 1773/1773 [03:29<00:00,  8.46it/s]


Epoch 2: train loss = 719186.5625, test loss = 71.29034423828125


100%|██████████| 7112/7112 [1:03:24<00:00,  1.87it/s]
100%|██████████| 1773/1773 [03:42<00:00,  7.96it/s]


Epoch 3: train loss = 7502.4775390625, test loss = 66.60196685791016


 23%|██▎       | 1667/7112 [14:30<47:24,  1.91it/s]  


KeyboardInterrupt: 