<a href="https://colab.research.google.com/github/cyteena/U-net/blob/main/Unet_more_robust.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install lpips

Collecting lpips
  Downloading lpips-0.1.4-py3-none-any.whl.metadata (10 kB)
Downloading lpips-0.1.4-py3-none-any.whl (53 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m53.8/53.8 kB[0m [31m1.8 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: lpips
Successfully installed lpips-0.1.4


In [None]:
from google.colab import files

# Opens Colab's file picker; each chosen file is saved into the current working
# directory (/content in a default Colab session) and `upload` maps filename -> bytes.
# NOTE(review): the recorded run uploaded data_sunset.jpg, but the image-loading cell
# further down reads /content/data_cute_cat.jpg and /content/data_star_night.jpg —
# make sure those are the files uploaded here.
upload = files.upload()

Saving data_sunset.jpg to data_sunset.jpg


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
import torch.optim as optim
from PIL import Image
import lpips
from collections import deque

# Use the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

class Unet(nn.Module):
    """U-Net mapping a channel-stacked image pair to a dense per-pixel output map.

    Args:
        num_channels: feature widths, shallowest level first, e.g. ``[64, 128, 256, 512]``.
            Every entry except the last defines one encoder level (conv block followed by
            a 2x2 max-pool); the last entry is the bottleneck width. At least two entries
            are required.
        in_channels: channels of the network input. The default 6 matches the content
            and style RGB images concatenated on dim 1 in ``train_step``.
        out_channels: channels of the output map. The default 2 matches the two offset
            channels consumed by ``apply_deformation_field``.

    The output has the same height and width as the input. Sizes that are not divisible
    by ``2 ** (len(num_channels) - 1)`` are supported: each upsampled feature map is
    zero-padded to the size of its skip connection before concatenation.
    """

    def __init__(self, num_channels, in_channels=6, out_channels=2):
        super().__init__()
        if len(num_channels) < 2:
            raise ValueError("num_channels needs at least two entries (encoder levels + bottleneck)")

        # Encoder: in_channels -> c[0], then c[i] -> c[i + 1] for every level except the
        # last width, which belongs to the bottleneck.
        # With [64, 128, 256, 512]: [6 -> 64, 64 -> 128, 128 -> 256].
        self.encoder_layers = nn.ModuleList()
        self.encoder_layers.append(self.conv_block(in_channels, num_channels[0]))
        for i in range(len(num_channels) - 2):
            self.encoder_layers.append(self.conv_block(num_channels[i], num_channels[i + 1]))

        # Bottleneck: c[-2] -> c[-1] (256 -> 512 with the default widths).
        self.bottleneck = self.conv_block(num_channels[-2], num_channels[-1])

        # Decoder, deepest level first. Each transposed conv doubles H and W and maps
        # c[i] -> c[i - 1]; its output is concatenated with the c[i - 1]-channel skip
        # connection, so the following conv block receives 2 * c[i - 1] channels.
        # With [64, 128, 256, 512]: upconvs [512 -> 256, 256 -> 128, 128 -> 64] and
        # decoder blocks [512 -> 256, 256 -> 128, 128 -> 64].
        self.decoder_layers = nn.ModuleList()
        self.upconv_layers = nn.ModuleList()
        for i in range(len(num_channels) - 1, 0, -1):
            self.upconv_layers.append(
                nn.ConvTranspose2d(num_channels[i], num_channels[i - 1], kernel_size=2, stride=2))
            # 2 * c[i - 1] == c[i] when widths double per level (the default config), so the
            # layer shapes are unchanged there; this form also works when widths do not double.
            self.decoder_layers.append(self.conv_block(2 * num_channels[i - 1], num_channels[i - 1]))

        # 1x1 conv projecting the c[0] features to out_channels (64 -> 2 by default).
        self.output_layer = nn.Conv2d(num_channels[0], out_channels, kernel_size=1, padding=0)

    def conv_block(self, in_channels: int, out_channels: int):
        """Two (3x3 conv -> BatchNorm -> ReLU) stages; padding=1 keeps H and W unchanged."""
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )

    @staticmethod
    def _match_spatial_size(x, reference):
        """Zero-pad ``x`` so its last two dims (H, W) equal those of ``reference``.

        Max-pooling floors odd sizes (143 -> 71) while the stride-2 transposed conv only
        doubles (71 -> 142), so an upsampled map can be smaller than its skip connection.
        It is never larger, so the padding amounts are always >= 0.
        """
        diff_h = reference.shape[-2] - x.shape[-2]
        diff_w = reference.shape[-1] - x.shape[-1]
        if diff_h == 0 and diff_w == 0:
            return x
        # F.pad takes (left, right, top, bottom) for the last two dimensions.
        return F.pad(x, [diff_w // 2, diff_w - diff_w // 2, diff_h // 2, diff_h - diff_h // 2])

    def forward(self, x):
        """Run the U-Net.

        Args:
            x: tensor of shape (N, in_channels, H, W).

        Returns:
            Tensor of shape (N, out_channels, H, W).
        """
        # Encoder: keep each level's full-resolution features for its skip connection.
        skips = []
        for layer in self.encoder_layers:
            x = layer(x)
            skips.append(x)
            x = F.max_pool2d(x, kernel_size=2, stride=2)

        x = self.bottleneck(x)

        # Decoder: decoder step k pairs with the k-th deepest skip connection.
        for upconv, decoder, skip in zip(self.upconv_layers, self.decoder_layers, reversed(skips)):
            x = upconv(x)
            x = self._match_spatial_size(x, skip)
            x = torch.cat([skip, x], dim=1)
            x = decoder(x)

        return self.output_layer(x)


# Seed the RNG so the U-Net's weight initialisation (and therefore the run) is reproducible.
torch.manual_seed(0)

num_channels = [64, 128, 256, 512]

model_unet = Unet(num_channels).to(device)

# 256 is divisible by 2 ** (len(num_channels) - 1) = 8, so every max-pool halves the
# feature map exactly and each upsampled map lines up with its skip connection.
# 572 does not: 572 -> 286 -> 143 -> 71, and upsampling 71 gives 142 != 143.
IMAGE_SIZE = 256

transform = transforms.Compose([
    transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
    transforms.ToTensor()
])

# LPIPS perceptual distance with an AlexNet trunk, used as the training loss.
loss_fn = lpips.LPIPS(net="alex").to(device)


def image_to_tensor(image_path):
    """Load an image file as a batched tensor via the notebook's global ``transform``.

    The image is converted to RGB, so the result has shape (1, 3, H, W) with H and W set
    by ``transform``. The file is opened in a ``with`` block so its handle is released as
    soon as ``convert`` has decoded the pixels into a new image.
    """
    with Image.open(image_path) as img:
        rgb_image = img.convert('RGB')
    return transform(rgb_image).unsqueeze(0)  # add batch dimension

def apply_deformation_field(deformation_field, image):
    """Warp ``image`` by adding ``deformation_field`` to an identity sampling grid.

    Args:
        deformation_field: tensor of shape (N, 2, H, W). Channel 0 is added to the x
            (width) sampling coordinate and channel 1 to the y (height) coordinate, in
            ``grid_sample``'s normalized [-1, 1] units.
        image: tensor of shape (N, C, H, W) to resample.

    Returns:
        The warped image, shape (N, C, H, W). A zero field reproduces ``image``
        (up to float rounding). Samples outside the image are reflected back in.

    Raises:
        ValueError: if ``deformation_field`` is not (N, 2, H, W) for the image's H and W.
    """
    _, _, H, W = image.shape
    if (deformation_field.dim() != 4
            or deformation_field.shape[1] != 2
            or deformation_field.shape[-2:] != image.shape[-2:]):
        raise ValueError(
            f"deformation_field must have shape (N, 2, {H}, {W}); got {tuple(deformation_field.shape)}")

    # Identity grid, built directly on the image's device and dtype. With align_corners=True,
    # -1 and +1 address the centres of the corner pixels, matching linspace(-1, 1, ...).
    ys = torch.linspace(-1, 1, H, device=image.device, dtype=image.dtype)
    xs = torch.linspace(-1, 1, W, device=image.device, dtype=image.dtype)
    grid_y, grid_x = torch.meshgrid(ys, xs, indexing="ij")
    # grid_sample expects the last dimension ordered as (x, y).
    identity_grid = torch.stack((grid_x, grid_y), dim=-1).unsqueeze(0)  # (1, H, W, 2)

    offsets = deformation_field.permute(0, 2, 3, 1)  # (N, H, W, 2)
    sampling_grid = identity_grid + offsets

    return F.grid_sample(image, sampling_grid, mode="bilinear", padding_mode="reflection", align_corners=True)



def train_step(model, content_image, style_image, optimizer, loss_fn, scale_to_lpips_range=True):
    """Run one optimisation step of the deformation network.

    The model receives content and style concatenated on the channel axis and predicts a
    deformation field; the style image warped by that field is compared with the content
    image by ``loss_fn``.

    Args:
        model: network mapping the concatenated pair to a (N, 2, H, W) field
            (assumed; this is what ``apply_deformation_field`` consumes).
        content_image, style_image: (C, H, W) or (N, C, H, W) tensors; a missing batch
            axis is added.
        optimizer: optimizer over ``model``'s parameters.
        loss_fn: callable ``loss_fn(prediction, target)`` returning a scalar or one
            distance per sample (e.g. ``lpips.LPIPS``).
        scale_to_lpips_range: if True (default), map both images from [0, 1] to [-1, 1]
            before calling ``loss_fn``. LPIPS expects [-1, 1] inputs, whereas
            ``transforms.ToTensor`` produces [0, 1].

    Returns:
        The loss, averaged over the batch, as a Python float.
    """
    model.train()
    optimizer.zero_grad()

    if content_image.dim() == 3:
        content_image = content_image.unsqueeze(0)
    if style_image.dim() == 3:
        style_image = style_image.unsqueeze(0)

    mixed_image = torch.cat([content_image, style_image], dim=1)
    deformation_field = model(mixed_image)
    deformed_style_image = apply_deformation_field(deformation_field, style_image)

    prediction, target = deformed_style_image, content_image
    if scale_to_lpips_range:
        prediction, target = prediction * 2 - 1, target * 2 - 1

    # lpips.LPIPS returns one distance per sample; backward() needs a single value,
    # so reduce over the batch (a no-op for N == 1).
    loss = loss_fn(prediction, target).mean()
    loss.backward()
    optimizer.step()

    return loss.item()

def main(content_image_path, style_image_path):
    """Load the content and style images onto ``device``.

    Each path goes through ``image_to_tensor`` (content first, then style).

    Returns:
        A ``(content, style)`` tuple of batched tensors.
    """
    return tuple(
        image_to_tensor(path).to(device)
        for path in (content_image_path, style_image_path)
    )

Setting up [LPIPS] perceptual loss: trunk [alex], v[0.1], spatial [off]




Loading model from: /usr/local/lib/python3.10/dist-packages/lpips/weights/v0.1/alex.pth


  self.load_state_dict(torch.load(model_path, map_location='cpu'), strict=False)


In [None]:
content_image = "/content/data_cute_cat.jpg"
style_image = "/content/data_star_night.jpg"

# main() already loads both images onto `device`; reuse it rather than repeating the calls.
content, style = main(content_image, style_image)

# train_step concatenates the two on the channel axis, which needs matching N, H and W.
assert content.shape == style.shape, (content.shape, style.shape)

In [None]:
content.size()

torch.Size([1, 3, 256, 256])

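The next cell shows how to turn an image tensor of shape (N, C, H, W) back into a PIL image: drop the batch dimension N (which must be 1), then pass the remaining (C, H, W) tensor to `transforms.ToPILImage`.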

In [None]:
# ToPILImage converts a single (C, H, W) tensor, so squeeze away the leading batch axis first.
to_pil_image = transforms.ToPILImage()
image = to_pil_image(torch.squeeze(content, 0))

In [None]:
# Snapshots of the warped style image, appended to by the training loop.
outputs = list()

In [None]:

optimizer = optim.Adam(model_unet.parameters(), lr=3e-4, weight_decay=1e-5)

num_epoches = 1000
LOG_EVERY = 50        # epochs between average-loss printouts (also the averaging window)
SNAPSHOT_EVERY = 100  # epochs between warped-image snapshots appended to `outputs`

# Anneal once over the whole run. With T_max=100 the cosine schedule reaches eta_min at
# epoch 100 and then climbs back up, repeating every 200 epochs.
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epoches, eta_min=1e-7)

recent_losses = deque(maxlen=LOG_EVERY)

# The network input never changes during training, so build it once.
mixed = torch.cat([content, style], dim=1)

for epoch in range(num_epoches):
    loss = train_step(model_unet, content, style, optimizer, loss_fn)
    recent_losses.append(loss)

    if epoch % LOG_EVERY == 0:
        avg_loss = sum(recent_losses) / len(recent_losses)
        print(f'Epoch {epoch}, Average Loss: {avg_loss}')

    if epoch % SNAPSHOT_EVERY == 0:
        model_unet.eval()
        # Snapshots are only for visualisation. Without no_grad every stored tensor would
        # keep its whole autograd graph (and saved activations) alive.
        with torch.no_grad():
            deformation_field = model_unet(mixed)
            output = apply_deformation_field(deformation_field, style)
        outputs.append(output.cpu())  # `outputs` is created in an earlier cell
        model_unet.train()

    scheduler.step()

Epoch 0, Average Loss: 1.0835273265838623
Epoch 50, Average Loss: 1.0117256784439086
Epoch 100, Average Loss: 0.8771947336196899
Epoch 150, Average Loss: 0.839885458946228
Epoch 200, Average Loss: 0.891559261083603
Epoch 250, Average Loss: 0.8821085667610169
Epoch 300, Average Loss: 0.7899301671981811
Epoch 350, Average Loss: 0.7711140191555024
Epoch 400, Average Loss: 0.80642462849617
Epoch 450, Average Loss: 0.8026821494102478
Epoch 500, Average Loss: 0.7367021298408508
Epoch 550, Average Loss: 0.7224828255176544
Epoch 600, Average Loss: 0.7575263416767121
Epoch 650, Average Loss: 0.7784326756000519
Epoch 700, Average Loss: 0.719307199716568
Epoch 750, Average Loss: 0.7061834621429444
Epoch 800, Average Loss: 0.7571823573112488
Epoch 850, Average Loss: 0.7570263671875
Epoch 900, Average Loss: 0.6922887241840363
Epoch 950, Average Loss: 0.6765922439098359


In [None]:
# Inference. eval() makes BatchNorm use its running statistics, as the snapshots in the
# training loop do, instead of this single image's batch statistics; it also stops this
# forward pass from updating those statistics. no_grad() skips building an autograd graph.
model_unet.eval()
with torch.no_grad():
    mixed = torch.cat([content, style], dim=1)
    deformation_field = model_unet(mixed)
    output = apply_deformation_field(deformation_field, style)

In [None]:
output.size()

torch.Size([1, 3, 256, 256])

In [None]:
# Drop the leading batch axis (only removed when it has size 1) and write the final warp to disk.
image = to_pil_image(torch.squeeze(output, 0))
image.save("output4.jpg")

In [None]:
# Number of snapshots collected so far. The training cell appends one every 100 epochs,
# and `outputs` is not reset there. NOTE(review): the recorded 20 suggests the training
# cell was run twice.
len(outputs)

20

In [None]:
import os

# PIL's Image.save() does not create missing directories, so make sure the target exists.
OUTPUT_DIR = "/content/outputs"
os.makedirs(OUTPUT_DIR, exist_ok=True)

for i, snapshot in enumerate(outputs):
    image = to_pil_image(snapshot.squeeze(0))
    image.save(f"{OUTPUT_DIR}/output_{i}.jpg")

In [None]:
# Assemble the saved snapshots /content/outputs/output_0.jpg ... output_{N-1}.jpg
# into an animated GIF, one frame per snapshot.

import os

# The v2 API keeps imread's current behaviour without the v3-transition DeprecationWarning
# that the top-level imageio.imread emits.
import imageio.v2 as imageio

# Collect the frames in order, skipping any snapshot file that is missing.
image_frames = []
for i in range(len(outputs)):
    image_path = f"/content/outputs/output_{i}.jpg"
    if os.path.exists(image_path):
        image_frames.append(imageio.imread(image_path))
    else:
        print(f"Warning: File {image_path} not found. Skipping.")

# mimsave cannot write a GIF with no frames.
if image_frames:
    imageio.mimsave('output.gif', image_frames, fps=2)  # 2 frames per second; adjust as needed
else:
    print("Warning: no frames found; output.gif was not written.")

  image_frames.append(imageio.imread(image_path))
