In [62]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import torchvision.models as models
import torchvision
from torchvision import datasets, transforms
from torchmetrics.image import StructuralSimilarityIndexMeasure as SSIM
import plotly.express as px
import plotly.io as pio
import plotly.graph_objects as go
from tqdm.notebook import tqdm
import logging
import sys
sys.path.append('/content/scripts')

logging.basicConfig(level=logging.WARNING)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"{device=}")

device=device(type='cuda')


## DataLoading

The VGG-19 model has been pretrained on ImageNet dataset, which will also be the choice for the decoder to learn from to appropriate the weights.

In [5]:
import zipfile

with zipfile.ZipFile('/content/data/flickr-8k-images-with-captions.zip', 'r') as zip_ref:
    zip_ref.extractall('/content/data')

In [11]:
transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
])

dataset = datasets.ImageFolder('/content/data/Images', transform=transform)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=64, shuffle=True)

In [17]:
imgs, _ = next(iter(dataloader))
imgs = imgs * 255
imgs = imgs.permute(0, 2, 3, 1)
imgs = imgs.numpy()

fig = px.imshow(imgs[0])
fig.show('png')

In [19]:
%load_ext autoreload
%autoreload 2
from scripts.autoencoder import AutoEncoder
from scripts.styletransform import StyleTransformModel

In [35]:
vgg19 = models.vgg19(weights=models.VGG19_Weights.IMAGENET1K_V1)
autoencoder = AutoEncoder(vgg19.features).to(device)
stylemodel = StyleTransformModel(vgg19.features).to(device)

In [41]:
criterion = nn.MSELoss()
optimizer = optim.Adam(autoencoder.parameters(), lr=0.001)
ssim_metric = SSIM(data_range=1.).to(device)
train_loss = []
train_ssim = []

for epoch in tqdm(range(10)):
    for i, (imgs, _) in enumerate(dataloader):
        imgs = imgs.to(device)
        recon = autoencoder(imgs)
        loss = criterion(recon, imgs)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if i % 100 == 0:
            print(f"Epoch: {epoch+1} Batch: {i} Loss: {loss.item()}")
            train_loss.append(loss.item())
            train_ssim.append(ssim_metric(recon, imgs).item())

  0%|          | 0/10 [00:00<?, ?it/s]

Epoch: 1 Batch: 0 Loss: 0.22848300635814667
Epoch: 1 Batch: 100 Loss: 0.07480225712060928
Epoch: 2 Batch: 0 Loss: 0.07244136184453964
Epoch: 2 Batch: 100 Loss: 0.07068870216608047
Epoch: 3 Batch: 0 Loss: 0.0743018388748169
Epoch: 3 Batch: 100 Loss: 0.07814919948577881
Epoch: 4 Batch: 0 Loss: 0.0786142572760582
Epoch: 4 Batch: 100 Loss: 0.07941234856843948
Epoch: 5 Batch: 0 Loss: 0.07464554160833359
Epoch: 5 Batch: 100 Loss: 0.07756352424621582
Epoch: 6 Batch: 0 Loss: 0.07365424931049347
Epoch: 6 Batch: 100 Loss: 0.06745589524507523
Epoch: 7 Batch: 0 Loss: 0.07657511532306671
Epoch: 7 Batch: 100 Loss: 0.07507982105016708
Epoch: 8 Batch: 0 Loss: 0.07244288921356201
Epoch: 8 Batch: 100 Loss: 0.07594071328639984
Epoch: 9 Batch: 0 Loss: 0.06372599303722382
Epoch: 9 Batch: 100 Loss: 0.07314694672822952
Epoch: 10 Batch: 0 Loss: 0.06959649920463562
Epoch: 10 Batch: 100 Loss: 0.0786690041422844


In [45]:
metrics = {'Epoch': list(range(len(train_loss))), 'Loss': train_loss, 'SSIM': train_ssim}

fig = px.line(metrics, x='Epoch', y=['Loss', 'SSIM'], title='Training Loss and SSIM',
              color_discrete_map={'Loss': 'palevioletred', 'SSIM': 'royalblue'})

fig.update_layout(
    xaxis_title="epoch", yaxis_title="metric"
)
fig.show()

In [47]:
pio.write_html(fig, file='/content/viz/autoencoder_metrics.html')

In [66]:
from PIL import Image

sample = transform(Image.open('/content/musk.jpg')).to(device)
autoencoder.eval()
with torch.no_grad():
    recon = autoencoder(sample)

recon = recon.cpu().permute(1, 2, 0).numpy() * 225
sample = sample.cpu().permute(1, 2, 0).numpy() * 225

fig = go.Figure()

# Add original image subplot
fig.add_trace(go.Image(z=sample, name='Original',
                       hoverinfo='name+x+y+z', xaxis='x', yaxis='y'))

# Add reconstructed image subplot
fig.add_trace(go.Image(z=recon, name='Reconstructed',
                       hoverinfo='name+x+y+z', xaxis='x2', yaxis='y2'))

# Update layout to display images side by side
fig.update_layout(
    title='Original and Reconstructed Images',
    xaxis=dict(domain=[0, 0.45]),  # Adjust domain for original image
    yaxis=dict(domain=[0, 1]),
    xaxis2=dict(domain=[0.55, 1]),  # Adjust domain for reconstructed image
    yaxis2=dict(domain=[0, 1]),
    margin=dict(l=0, r=0, t=40, b=0),
)

fig.show()
print(sample, recon)

[[[ 28.235296  59.11765  133.2353  ]
  [ 29.11765   60.000004 134.11765 ]
  [ 30.882355  61.76471  135.88235 ]
  ...
  [ 62.64706   84.70589  171.17647 ]
  [ 62.64706   85.588234 171.17647 ]
  [ 61.76471   84.70589  170.29413 ]]

 [[ 29.11765   60.000004 134.11765 ]
  [ 30.882355  61.76471  135.88235 ]
  [ 32.64706   63.529415 137.64706 ]
  ...
  [ 63.529415  86.47059  172.05882 ]
  [ 64.411766  87.35294  172.94118 ]
  [ 62.64706   85.588234 171.17647 ]]

 [[ 30.000002  60.882355 135.      ]
  [ 32.64706   63.529415 137.64706 ]
  [ 35.29412   66.176476 140.29413 ]
  ...
  [ 66.176476  89.117645 174.70589 ]
  [ 65.29412   88.2353   173.82353 ]
  [ 62.64706   85.588234 171.17647 ]]

 ...

 [[199.41177  146.4706   139.41177 ]
  [201.17647  149.11765  141.17648 ]
  [200.29411  149.11765  140.29413 ]
  ...
  [ 47.64706   27.352942  24.705883]
  [ 46.764706  26.470589  24.705883]
  [ 46.764706  28.235296  26.470589]]

 [[171.17647  117.35295  112.05882 ]
  [172.05882  120.00001  112.941185]
