In [2]:
from PIL import Image
import os
import torch
import hashlib
import tarfile
import requests
import numpy as np
from tqdm import tqdm
from PIL import Image
import torch
import torchvision
import matplotlib.pyplot as plt
from torchvision.transforms import v2
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torch.optim as optim
from dataset import LFWDataset

In [4]:
class UNet2(nn.Module):
    """Small convolutional encoder/decoder that emits per-pixel channel scores.

    Layout: two 3x3 convolutions -> 2x2 max-pool -> two 3x3 convolutions ->
    one stride-2 transposed convolution. There are no skip connections, so
    the decoder only sees the pooled feature map.

    Args:
        in_channels: number of channels of the input tensor.
        out_channels: number of channels of the output tensor.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        width = 64  # feature channels used by every hidden layer

        # Encoder: 3x3 convs with padding=1 keep H and W, the pool halves them.
        self.enc_11 = nn.Conv2d(in_channels, width, kernel_size=3, padding=1)
        self.enc_12 = nn.Conv2d(width, width, kernel_size=3, padding=1)
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)

        # Decoder: two more size-preserving convs, then a 2x upsampling
        # transposed conv that also projects down to `out_channels`.
        self.dec_21 = nn.Conv2d(width, width, kernel_size=3, padding=1)
        self.dec_22 = nn.Conv2d(width, width, kernel_size=3, padding=1)
        self.upconv1 = nn.ConvTranspose2d(width, out_channels, kernel_size=2, stride=2)

    def forward(self, x):
        """Map a (N, in_channels, H, W) tensor to (N, out_channels, 2*(H//2), 2*(W//2)).

        The output has the same spatial size as the input only when H and W
        are even, because the pool floors odd sizes before the 2x upsampling.
        """
        encoded = F.relu(self.enc_12(F.relu(self.enc_11(x))))
        pooled = self.pool1(encoded)
        decoded = F.relu(self.dec_22(F.relu(self.dec_21(pooled))))
        return self.upconv1(decoded)

In [6]:
# No transform object is handed to the dataset.
transform = None

# Build the dataset path with os.path.join instead of a backslash literal:
# '..\cvdl_lab_4\lfw_dataset' relies on the invalid escapes '\c' and '\l'
# (a SyntaxWarning on recent Python versions) and only resolves on Windows.
# On Windows this produces exactly the same string as before.
dataset_root = os.path.join('..', 'cvdl_lab_4', 'lfw_dataset')
train_dataset = LFWDataset(base_folder=dataset_root, transforms=transform, download=False, split_name='train')
test_dataset = LFWDataset(base_folder=dataset_root, transforms=transform, download=False, split_name='test')

# Shuffle only the training split; keep evaluation order fixed.
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)

# 3 input channels, 3 output channels (one score map per mask class).
model = UNet2(in_channels=3, out_channels=3)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Training loop configuration
num_epochs = 10
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Kept as the last expression so the notebook displays the module summary.
model.to(device)

UNet2(
  (enc_11): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (enc_12): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (pool1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (dec_21): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (dec_22): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (upconv1): ConvTranspose2d(64, 3, kernel_size=(2, 2), stride=(2, 2))
)

In [7]:
def _mask_to_class_ids(mask):
    """Convert raw mask values into class indices for CrossEntropyLoss.

    Mapping: 29 -> 0, 76 -> 1, and every other value (150 is the expected
    one) -> 2. The result is an int64 tensor with the same shape and on the
    same device as `mask`, so no CPU scalar tensors are mixed with device
    tensors.
    """
    class_ids = torch.full_like(mask, 2, dtype=torch.long)
    class_ids[mask == 29] = 0
    class_ids[mask == 76] = 1
    return class_ids


for epoch in range(num_epochs):
    # ---- Training pass ----
    model.train()
    total_loss = 0.0

    for inputs, targets in tqdm(train_loader, desc=f"Epoch {epoch + 1}/{num_epochs}"):
        # Move the last axis to position 1 (channels-last -> channels-first)
        # as expected by Conv2d; assumes the loader yields (N, H, W, C) batches.
        inputs = inputs.to(device).float().permute(0, 3, 1, 2)
        # The per-batch targets.tolist() conversion was dropped: its result
        # was never used and it only slowed every step.
        targets = _mask_to_class_ids(targets.to(device))

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, targets)

        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    # Mean of the per-batch losses.
    average_loss = total_loss / len(train_loader)
    print(f"Epoch {epoch + 1}/{num_epochs}, Loss: {average_loss:.4f}")

    # ---- Validation pass: per-pixel accuracy over the test split ----
    model.eval()
    total_correct = 0
    total_samples = 0

    with torch.no_grad():
        for inputs, targets in tqdm(test_loader, desc="Validation"):
            inputs = inputs.to(device).float().permute(0, 3, 1, 2)
            targets = _mask_to_class_ids(targets.to(device))

            outputs = model(inputs)
            # Predicted class = index of the highest score along the channel axis.
            _, predicted = torch.max(outputs, 1)

            total_samples += targets.numel()
            total_correct += (predicted == targets).sum().item()

    print(f"Validation Accuracy: {total_correct / total_samples:.4f}")

Epoch 1/10: 100%|██████████████████████████████████████████████████████████████████████| 37/37 [08:27<00:00, 13.73s/it]


Epoch 1/10, Loss: 0.9846


Validation: 100%|██████████████████████████████████████████████████████████████████████| 10/10 [00:31<00:00,  3.19s/it]


Validation Accuracy: 0.7730


Epoch 2/10: 100%|██████████████████████████████████████████████████████████████████████| 37/37 [08:36<00:00, 13.95s/it]


Epoch 2/10, Loss: 0.5772


Validation: 100%|██████████████████████████████████████████████████████████████████████| 10/10 [00:31<00:00,  3.10s/it]


Validation Accuracy: 0.7912


Epoch 3/10: 100%|██████████████████████████████████████████████████████████████████████| 37/37 [07:25<00:00, 12.03s/it]


Epoch 3/10, Loss: 0.5534


Validation: 100%|██████████████████████████████████████████████████████████████████████| 10/10 [00:29<00:00,  2.91s/it]


Validation Accuracy: 0.7945


Epoch 4/10: 100%|██████████████████████████████████████████████████████████████████████| 37/37 [07:21<00:00, 11.94s/it]


Epoch 4/10, Loss: 0.5519


Validation: 100%|██████████████████████████████████████████████████████████████████████| 10/10 [00:28<00:00,  2.89s/it]


Validation Accuracy: 0.7978


Epoch 5/10: 100%|██████████████████████████████████████████████████████████████████████| 37/37 [07:24<00:00, 12.00s/it]


Epoch 5/10, Loss: 0.5281


Validation: 100%|██████████████████████████████████████████████████████████████████████| 10/10 [00:28<00:00,  2.88s/it]


Validation Accuracy: 0.8003


Epoch 6/10: 100%|██████████████████████████████████████████████████████████████████████| 37/37 [07:21<00:00, 11.93s/it]


Epoch 6/10, Loss: 0.5267


Validation: 100%|██████████████████████████████████████████████████████████████████████| 10/10 [00:28<00:00,  2.87s/it]


Validation Accuracy: 0.8015


Epoch 7/10: 100%|██████████████████████████████████████████████████████████████████████| 37/37 [07:22<00:00, 11.97s/it]


Epoch 7/10, Loss: 0.5149


Validation: 100%|██████████████████████████████████████████████████████████████████████| 10/10 [00:28<00:00,  2.88s/it]


Validation Accuracy: 0.8019


Epoch 8/10: 100%|██████████████████████████████████████████████████████████████████████| 37/37 [07:21<00:00, 11.93s/it]


Epoch 8/10, Loss: 0.5164


Validation: 100%|██████████████████████████████████████████████████████████████████████| 10/10 [00:28<00:00,  2.87s/it]


Validation Accuracy: 0.7992


Epoch 9/10: 100%|██████████████████████████████████████████████████████████████████████| 37/37 [07:32<00:00, 12.23s/it]


Epoch 9/10, Loss: 0.4998


Validation: 100%|██████████████████████████████████████████████████████████████████████| 10/10 [00:30<00:00,  3.05s/it]


Validation Accuracy: 0.8111


Epoch 10/10: 100%|█████████████████████████████████████████████████████████████████████| 37/37 [08:18<00:00, 13.48s/it]


Epoch 10/10, Loss: 0.4871


Validation: 100%|██████████████████████████████████████████████████████████████████████| 10/10 [00:32<00:00,  3.29s/it]

Validation Accuracy: 0.8106





In [8]:
# Compile the trained network to TorchScript and write it to disk so it can
# be restored later with torch.jit.load.
scripted_model = torch.jit.script(model)
torch.jit.save(scripted_model, "final_scripted_model.pt")

In [13]:
from io import BytesIO
import requests
import gradio as gr

In [14]:
# Load the TorchScript model for the demo. map_location="cpu" pins it to the
# CPU regardless of the device it was saved from, because segment_image feeds
# it CPU tensors; .eval() makes inference mode explicit (it returns the module).
my_model = torch.jit.load("final_scripted_model.pt", map_location="cpu").eval()

In [15]:
def segment_image(input_image, model=None):
    """Segment one image and return the predicted mask as a grayscale image.

    Args:
        input_image: channel-last (H, W, C) image array (presumably what
            ``gr.Image()`` delivers by default - confirm against the Gradio
            version in use), or ``None`` when nothing was submitted.
        model: optional callable mapping a (1, C, H, W) float tensor to
            (1, num_classes, H', W') scores. Defaults to the module-level
            ``my_model``.

    Returns:
        A uint8 array of shape (H', W') in which class ``k`` is drawn with the
        gray level ``k * (255 // (num_classes - 1))`` (0, 127, 254 for three
        classes), or ``None`` if ``input_image`` is ``None``.
    """
    if input_image is None:
        return None
    if model is None:
        model = my_model

    image_array = np.asarray(input_image)
    # HWC -> CHW plus a batch axis. This keeps rows and columns in the same
    # order the training loop uses (permute(0, 3, 1, 2)); swapping axes 0 and 2
    # would feed the network a spatially transposed image.
    image_tensor = torch.tensor(image_array).permute(2, 0, 1).unsqueeze(0).float()

    with torch.no_grad():
        output = model(image_tensor)

    # Highest-scoring class per pixel; squeeze only the batch axis so that
    # 1-pixel-wide images keep their two spatial dimensions.
    class_ids = torch.argmax(output, dim=1).squeeze(0).cpu().numpy()

    # Spread the class indices over the uint8 range. Multiplying by 255
    # directly would wrap (2 * 255 -> 254) and make classes 1 and 2 look alike.
    num_classes = output.shape[1]
    gray_step = 255 // max(num_classes - 1, 1)
    return (class_ids * gray_step).astype(np.uint8)

# Gradio Interface
# Image-in / image-out demo around segment_image, with five bundled example
# pictures offered as inputs.
example_images = [
    'celeb_inputs/celeb1.jpg',
    'celeb_inputs/celeb2.jpg',
    'celeb_inputs/celeb3.jpg',
    'celeb_inputs/celeb4.jpg',
    'celeb_inputs/celeb5.jpg',
]

ui = gr.Interface(
    fn=segment_image,
    inputs=gr.Image(),
    outputs=gr.Image(),
    examples=example_images,
)

In [16]:
# Start the local web server for the `ui` interface; the URL is printed below.
ui.launch()

Running on local URL:  http://127.0.0.1:7860

To create a public link, set `share=True` in `launch()`.




In [17]:
# No transform object is handed to the dataset.
transform = None

# Build the dataset path with os.path.join instead of a backslash literal:
# '..\cvdl_lab_4\lfw_dataset' relies on the invalid escapes '\c' and '\l'
# (a SyntaxWarning on recent Python versions) and only resolves on Windows.
# On Windows this produces exactly the same string as before.
dataset_root = os.path.join('..', 'cvdl_lab_4', 'lfw_dataset')
train_dataset = LFWDataset(base_folder=dataset_root, transforms=transform, download=False, split_name='train')
test_dataset = LFWDataset(base_folder=dataset_root, transforms=transform, download=False, split_name='test')

# Second experiment: same pipeline with a much smaller batch size (4).
train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=4, shuffle=False)

# Fresh, untrained model and optimizer for this run.
model = UNet2(in_channels=3, out_channels=3)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Training loop configuration
num_epochs = 10
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Kept as the last expression so the notebook displays the module summary.
model.to(device)

UNet2(
  (enc_11): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (enc_12): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (pool1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (dec_21): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (dec_22): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (upconv1): ConvTranspose2d(64, 3, kernel_size=(2, 2), stride=(2, 2))
)

In [18]:
def _mask_to_class_ids(mask):
    """Convert raw mask values into class indices for CrossEntropyLoss.

    Mapping: 29 -> 0, 76 -> 1, and every other value (150 is the expected
    one) -> 2. The result is an int64 tensor with the same shape and on the
    same device as `mask`, so no CPU scalar tensors are mixed with device
    tensors.
    """
    class_ids = torch.full_like(mask, 2, dtype=torch.long)
    class_ids[mask == 29] = 0
    class_ids[mask == 76] = 1
    return class_ids


for epoch in range(num_epochs):
    # ---- Training pass ----
    model.train()
    total_loss = 0.0

    for inputs, targets in tqdm(train_loader, desc=f"Epoch {epoch + 1}/{num_epochs}"):
        # Move the last axis to position 1 (channels-last -> channels-first)
        # as expected by Conv2d; assumes the loader yields (N, H, W, C) batches.
        inputs = inputs.to(device).float().permute(0, 3, 1, 2)
        # The per-batch targets.tolist() conversion was dropped: its result
        # was never used and it only slowed every step.
        targets = _mask_to_class_ids(targets.to(device))

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, targets)

        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    # Mean of the per-batch losses.
    average_loss = total_loss / len(train_loader)
    print(f"Epoch {epoch + 1}/{num_epochs}, Loss: {average_loss:.4f}")

    # ---- Validation pass: per-pixel accuracy over the test split ----
    model.eval()
    total_correct = 0
    total_samples = 0

    with torch.no_grad():
        for inputs, targets in tqdm(test_loader, desc="Validation"):
            inputs = inputs.to(device).float().permute(0, 3, 1, 2)
            targets = _mask_to_class_ids(targets.to(device))

            outputs = model(inputs)
            # Predicted class = index of the highest score along the channel axis.
            _, predicted = torch.max(outputs, 1)

            total_samples += targets.numel()
            total_correct += (predicted == targets).sum().item()

    print(f"Validation Accuracy: {total_correct / total_samples:.4f}")

Epoch 1/10: 100%|████████████████████████████████████████████████████████████████████| 586/586 [07:02<00:00,  1.39it/s]


Epoch 1/10, Loss: 0.6047


Validation: 100%|████████████████████████████████████████████████████████████████████| 147/147 [00:31<00:00,  4.62it/s]


Validation Accuracy: 0.7905


Epoch 2/10: 100%|████████████████████████████████████████████████████████████████████| 586/586 [07:28<00:00,  1.31it/s]


Epoch 2/10, Loss: 0.5222


Validation: 100%|████████████████████████████████████████████████████████████████████| 147/147 [00:32<00:00,  4.58it/s]


Validation Accuracy: 0.8039


Epoch 3/10: 100%|████████████████████████████████████████████████████████████████████| 586/586 [07:28<00:00,  1.31it/s]


Epoch 3/10, Loss: 0.4984


Validation: 100%|████████████████████████████████████████████████████████████████████| 147/147 [00:33<00:00,  4.37it/s]


Validation Accuracy: 0.7794


Epoch 4/10: 100%|████████████████████████████████████████████████████████████████████| 586/586 [07:27<00:00,  1.31it/s]


Epoch 4/10, Loss: 0.4712


Validation: 100%|████████████████████████████████████████████████████████████████████| 147/147 [00:31<00:00,  4.69it/s]


Validation Accuracy: 0.8181


Epoch 5/10: 100%|████████████████████████████████████████████████████████████████████| 586/586 [07:23<00:00,  1.32it/s]


Epoch 5/10, Loss: 0.4546


Validation: 100%|████████████████████████████████████████████████████████████████████| 147/147 [00:32<00:00,  4.52it/s]


Validation Accuracy: 0.8342


Epoch 6/10: 100%|████████████████████████████████████████████████████████████████████| 586/586 [07:27<00:00,  1.31it/s]


Epoch 6/10, Loss: 0.4470


Validation: 100%|████████████████████████████████████████████████████████████████████| 147/147 [00:32<00:00,  4.59it/s]


Validation Accuracy: 0.8356


Epoch 7/10: 100%|████████████████████████████████████████████████████████████████████| 586/586 [07:25<00:00,  1.32it/s]


Epoch 7/10, Loss: 0.4432


Validation: 100%|████████████████████████████████████████████████████████████████████| 147/147 [00:31<00:00,  4.70it/s]


Validation Accuracy: 0.8340


Epoch 8/10: 100%|████████████████████████████████████████████████████████████████████| 586/586 [07:25<00:00,  1.31it/s]


Epoch 8/10, Loss: 0.4413


Validation: 100%|████████████████████████████████████████████████████████████████████| 147/147 [00:32<00:00,  4.50it/s]


Validation Accuracy: 0.8387


Epoch 9/10: 100%|████████████████████████████████████████████████████████████████████| 586/586 [07:23<00:00,  1.32it/s]


Epoch 9/10, Loss: 0.4328


Validation: 100%|████████████████████████████████████████████████████████████████████| 147/147 [00:31<00:00,  4.70it/s]


Validation Accuracy: 0.8419


Epoch 10/10: 100%|███████████████████████████████████████████████████████████████████| 586/586 [07:20<00:00,  1.33it/s]


Epoch 10/10, Loss: 0.4302


Validation: 100%|████████████████████████████████████████████████████████████████████| 147/147 [00:31<00:00,  4.67it/s]

Validation Accuracy: 0.8412





In [19]:
# Compile the second trained network to TorchScript and store it under its
# own file name, separate from the first run's export.
scripted_model = torch.jit.script(model)
torch.jit.save(scripted_model, "var2_scripted_model.pt")