In [1]:
import torch
from torch import nn
from PIL import Image, ImageDraw
import numpy as np
import random
import os
import wandb
from torch.utils.data import DataLoader
import PIL
import matplotlib.pyplot as plt

device = torch.device('cuda')

In [2]:
def image_to_tensor(image: Image) -> torch.Tensor:
    img_array = np.array(image)
    img_array = img_array/255
    img_array = img_array.transpose(2, 0, 1).astype(np.float32)
    img_tensor = torch.from_numpy(img_array)
    img_tensor = img_tensor.unsqueeze(0)
    return img_tensor

def tensor_to_image(tensor: torch.Tensor) -> Image.Image:
    """
    Convert a Tensor to a PIL Image.

    The tensor must have shape `[1, channels, height, width]` where the number of
    channels is either 1 (grayscale) or 3 (RGB) or 4 (RGBA).

    Expected values are in the range `[0, 1]` and are clamped to this range.
    """
    assert tensor.ndim == 4 and tensor.shape[0] == 1, f"Unsupported tensor shape: {tensor.shape}"
    num_channels = tensor.shape[1]
    tensor = tensor.clamp(0, 1).squeeze(0)

    match num_channels:
        case 1:
            tensor = tensor.squeeze(0)
        case 3 | 4:
            tensor = tensor.permute(1, 2, 0)
        case _:
            raise ValueError(f"Unsupported number of channels: {num_channels}")

    return Image.fromarray((tensor.cpu().numpy() * 255).astype("uint8"))  # type: ignore[reportUnknownType]

In [10]:
class TransformerEncoderBlock(nn.Module):
    def __init__(self, input_dim, nhead=2):
        super(TransformerEncoderBlock, self).__init__()
        self.transformer = nn.TransformerEncoderLayer(d_model=input_dim, nhead=nhead)

    def forward(self, x):
        return self.transformer(x)

class TransformerDecoderBlock(nn.Module):
    def __init__(self, input_dim, nhead=2):
        super(TransformerDecoderBlock, self).__init__()
        self.transformer = nn.TransformerDecoderLayer(d_model=input_dim, nhead=nhead)

    def forward(self, x):
        return self.transformer(x)

In [13]:
class Encoder(nn.Module):
    def __init__(self, input_channels: int = 3):
        super().__init__()
        self.conv1 = nn.Conv2d(input_channels, 32, 3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, 3, padding=1)
        # self.conv3 = nn.Conv2d(64, 128, 3, padding=1)
        # self.conv4 = nn.Conv2d(128, 256, 3, padding=1)
        self.silu = nn.SiLU()
        self.maxpool = nn.MaxPool2d(2)

        self.flatten = nn.Flatten()
        # Transformer Encoder Layer
        # self.transformer = nn.TransformerEncoderLayer(d_model=64, nhead=2) 
        self.transformer_block = TransformerEncoderBlock(64)
        self.conv1_1 = nn.Conv2d(64, 4, 1, padding=0)
    
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        y = self.conv1(x)
        y = self.silu(y)
        y = self.maxpool(y)
        y = self.conv2(y)
        y = self.silu(y)
        y = self.maxpool(y)

        # Apply transformer
        y = self.flatten(y)
        # y = y.unsqueeze(0)  # Add sequence dimension
        y = self.transformer_block(y)
        # y = y.squeeze(0)  # Remove sequence dimension

        # y = self.conv3(y)
        # y = self.silu(y)
        # y = self.maxpool(y)
        # y = self.conv4(y)
        # y = self.silu(y)
        # y = self.maxpool(y)
        y = self.conv1_1(y)
        return y

class Decoder(nn.Module):
    def __init__(self, output_channels: int = 3):
        super().__init__()
        # self.conv1 = nn.Conv2d(256, 128, 3, padding=1)
        # self.conv2 = nn.Conv2d(128, 64, 3, padding=1)
        self.conv3 = nn.Conv2d(64, 32, 3, padding=1)
        self.conv4 = nn.Conv2d(32, output_channels, 3, padding=1)
        self.silu = nn.SiLU()
        self.upsample = nn.Upsample(scale_factor=2)

        # Transformer Decoder Layer
        # self.transformer = nn.TransformerDecoderLayer(d_model=64, nhead=2)
        self.transformer_block = TransformerDecoderBlock(64)
        self.embedding = nn.Linear(64, 64)
    
        self.conv1_1 = nn.Conv2d(4, 64, 1, padding=0)
    
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        y = self.conv1_1(x)

        # y = x.unsqueeze(0)  # Add sequence dimension
        y = self.transformer_block(y)
        # y = y.squeeze(0)  # Remove sequence dimension
        # y = self.embedding(y)
        # y = y.view(y.size(0), -1)
        # y = y.view(y.size(0), 64, 1, 1)

        
        # y = self.conv1(y)
        # y = self.silu(y)
        # y = self.upsample(y)
        # y = self.conv2(y)
        # y = self.silu(y)
        # y = self.upsample(y)
        y = self.conv3(y)
        y = self.silu(y)
        y = self.upsample(y)
        y = self.conv4(y)
        y = self.silu(y)
        y = self.upsample(y)
        return y
    
class AutoEncoder(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.encoder = Encoder()
        self.decoder = Decoder()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.decoder(self.encoder(x))
    

In [4]:
def generate_images(size, num_images, output_folder):
    # Create the output folder if it doesn't exist
    os.makedirs(output_folder, exist_ok=True)

    for i in range(1, num_images + 1):
        # Create a new image with a random background color
        img = Image.new("RGB", size, color=(random.randint(0, 255), random.randint(0, 255), random.randint(0, 255)))

        # Get a drawing context
        draw = ImageDraw.Draw(img)

        # Choose a random shape (circle, square, or triangle)
        # shape = random.choice(["circle", "square", "triangle"])
        shape = "circle"

        # Choose a random position
        position = (random.randint(20, size[0]-20), random.randint(20, size[1]-20))

        # Choose a random color for the shape
        shape_color = (random.randint(0, 255), random.randint(0, 255), random.randint(0, 255))

        # Draw the shape on the image
        if shape == "circle":
            draw.ellipse([position[0]-20, position[1]-20, position[0]+20, position[1]+20], fill=shape_color)
        elif shape == "square":
            draw.rectangle([position[0]-20, position[1]-20, position[0]+20, position[1]+20], fill=shape_color)
        elif shape == "triangle":
            draw.polygon([(position[0], position[1]-20), (position[0]-20, position[1]+20), (position[0]+20, position[1]+20)], fill=shape_color)

        # Save the image to the output folder
        img.save(os.path.join(output_folder, f"image_{i}.png"))

# Example usage
generate_images((128, 128), 10, "../data/dataset_train")
generate_images((128, 128), 1, "../data/dataset_test")

In [5]:
class Dataset:
    def __init__(self, path) -> None:
        self.data = list(range(100))
        self.path = path

    def __len__(self) -> int:
        return len(self.data)

    def __str__(self) -> str:
        return f'Dataset(len={len(self)})'

    def __repr__(self) -> str:
        return str(self)
    
    def __getitem__(self, key : str|int) -> int:
        match key:
            case key if isinstance(key, str):
                raise ValueError('Dataset does not take string as index.')
            case _:
                return self.data[key]

In [6]:
class ImageDataset:
    def __init__(self, path) -> None:
        self.path = path
        self.image_files = [f for f in os.listdir(path) if f.endswith(('.jpg', '.png', '.jpeg'))]
        self.data = [self.load_image(file) for file in self.image_files]

    def __len__(self) -> int:
        return len(self.data)

    def __str__(self) -> str:
        return f'ImageDataset(len={len(self)})'

    def __repr__(self) -> str:
        return str(self)

    def __getitem__(self, key: int) -> Image.Image:
        return self.data[key]

    def load_image(self, file: str) -> Image.Image:
        image_path = os.path.join(self.path, file)
        try:
            image = Image.open(image_path)
            return image
        except Exception as e:
            print(f"Error loading image '{file}': {e}")
            return None


In [None]:
lr = 1e-4
num_steps = 1000

# Initialize wandb
wandb.init(project='finegrain-cs', config={
    'lr': lr,
    'num_steps': num_steps
})

path_dataset_train = "../data/dataset_train/"
path_dataset_test = "../data/dataset_test/"
dataset_train = ImageDataset(path_dataset_train).data
dataset_test = ImageDataset(path_dataset_test).data
autoencoder = AutoEncoder()
autoencoder.train()


optimizer = torch.optim.Adam(autoencoder.parameters() , lr=lr)

# Convert your datasets to DataLoader for easy batch processing
# train_dataloader = DataLoader(dataset_train, batch_size=8, shuffle=True)

# Modify your training loop
for step in range(num_steps):
    loss_iter = 0
    for images in dataset_train:
        images = image_to_tensor(images)
        y = autoencoder(images)
        loss = (y - images).norm()
        loss_iter += loss

    loss = loss_iter / len(dataset_train)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

    # Log the loss to wandb
    wandb.log({'step': step, 'loss': loss.item()})

    print(f"step {step} : loss {loss.item()}")

# Testing at the end of training
loss_iter_test = 0
reconstructed_images = []

with torch.no_grad():
    for images_test in dataset_test:
        images_test = image_to_tensor(images_test)
        y_test = autoencoder(images_test)
        loss_test = (y_test - images_test).norm()
        loss_iter_test += loss_test

        # # Append the reconstructed images for visualization
        concat = Image.new('RGB', (256, 128))
        concat.paste(tensor_to_image(images_test.data), (0, 0))
        concat.paste(tensor_to_image(y_test.data), (128, 0))
        
        reconstructed_images.append(concat)

images = [PIL.Image.fromarray(np.array(image)) for image in reconstructed_images]

wandb.log({"reconstructed_images": [wandb.Image(image) for image in images]})

loss_test = loss_iter_test / len(dataset_test)

# Log the test loss to wandb
wandb.log({'test_loss': loss_test.item()})

print(f"Test Loss at the end of training: {loss_test.item()}")


In [14]:
path_dataset_train = "../data/dataset_train/"
path_dataset_test = "../data/dataset_test/"
dataset_train = ImageDataset(path_dataset_train).data
dataset_test = ImageDataset(path_dataset_test).data
autoencoder = AutoEncoder()
autoencoder.train()
lr = 1e-4
num_steps = 100

optimizer = torch.optim.Adam(autoencoder.parameters() , lr=lr)
for step in range(num_steps):
    loss_iter = 0
    for image in dataset_train:
        image = image_to_tensor(image)
        y = autoencoder(image)
        loss = (y-image).norm()
        loss_iter += loss
    loss = loss_iter/len(dataset_train)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    print(f"step {step} : loss {loss.item()}")

autoencoder.eval()
for image in dataset_test:
    image = image_to_tensor(image)
    result = autoencoder(image)
    tensor_to_image(image.data).show()
    tensor_to_image(result.data).show()

AssertionError: query should be unbatched 2D or batched 3D tensor but received 4-D query tensor

In [None]:
tensor = image_to_tensor(dataset_test[0])
latent = autoencoder.encoder(tensor)
latent = (latent - latent.min())/(latent.max() - latent.min())


#split latent into 4 channels
channels = latent.squeeze(0).split(1)

images = [tensor_to_image(channel.unsqueeze(0).detach()) for channel in channels]

#write function to make grid using matplotlib in grayscale colorscheme
def make_grid_using_matplotlib(images):
    fig = plt.figure(figsize=(8, 8))

    columns = 2
    rows = 2
    for i in range(1, columns*rows +1):
        fig.add_subplot(rows, columns, i)

        plt.imshow(images[i-1], cmap='gray')
    plt.show()

make_grid_using_matplotlib(images)
