In [1]:
import torch
from torch import nn
from PIL import Image, ImageDraw
import numpy as np
import random
import os
import wandb
from torch.utils.data import DataLoader
import PIL
import matplotlib.pyplot as plt

from refiners.fluxion import layers as fl
from refiners.fluxion.adapters.adapter import Adapter
from refiners.fluxion.context import Contexts
from refiners.fluxion import utils

# Use the GPU when one is available, otherwise fall back to the CPU so that the
# later `.to(device)` calls do not fail on machines without CUDA.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [2]:
class DownBlock(fl.Chain):
    """Down-sampling path: three 3x3 convolutions (32->64->128->256 channels)
    with a 2x ``fl.Downsample`` between consecutive convolutions.

    The output of each convolution is pushed onto the ``residuals`` list of the
    ``unet`` context (``UpBlock`` pops from that same list).
    """

    def __init__(self) -> None:
        super().__init__(
            fl.Conv2d(in_channels=32, out_channels=64, kernel_size=3, padding=1),
            fl.Lambda(self.append_residual),
            fl.Downsample(channels=64, scale_factor=2),
            fl.Conv2d(in_channels=64, out_channels=128, kernel_size=3, padding=1),
            fl.Lambda(self.append_residual),
            fl.Downsample(channels=128, scale_factor=2),
            fl.Conv2d(in_channels=128, out_channels=256, kernel_size=3, padding=1),
            fl.Lambda(self.append_residual),
        )

    def append_residual(self, x: torch.Tensor):
        """Store ``x`` in the ``unet`` context's ``residuals`` list and pass it through unchanged."""
        residuals = self.use_context("unet")["residuals"]
        residuals.append(x)
        return x


class MiddleBlock(fl.Residual):
    """Bottleneck stage: one 3x3 convolution on 256 channels wrapped in ``fl.Residual``."""

    def __init__(self) -> None:
        super().__init__(
            fl.Conv2d(in_channels=256, out_channels=256, kernel_size=3, padding=1),
        )


class UpBlock(fl.Chain):
    """Up-sampling path of the UNet.

    Each stage concatenates the incoming tensor (along ``dim=1``) with the last
    entry popped from the ``unet`` context's ``residuals`` list, then applies a
    3x3 convolution (512->128, 256->64, 128->32 channels). The first two stages
    are followed by an ``fl.Upsample``.
    """

    def __init__(self) -> None:
        def concat_with_residual() -> fl.Concatenate:
            # Builds a fresh skip-connection layer; every call returns a new module.
            return fl.Concatenate(
                fl.Identity(),
                fl.UseContext(context="unet", key="residuals").compose(func=lambda residuals: residuals.pop(-1)),
                dim=1,
            )

        # Layers are created in the same order as they are applied.
        super().__init__(
            concat_with_residual(),
            fl.Conv2d(in_channels=512, out_channels=128, kernel_size=3, padding=1),
            fl.Upsample(channels=128),
            concat_with_residual(),
            fl.Conv2d(in_channels=256, out_channels=64, kernel_size=3, padding=1),
            fl.Upsample(channels=64),
            concat_with_residual(),
            fl.Conv2d(in_channels=128, out_channels=32, kernel_size=3, padding=1),
        )


class UNet(fl.Chain):
    """Small UNet: 1x1 conv (3->32), ``DownBlock``, ``MiddleBlock``, ``UpBlock``, 1x1 conv (32->3)."""

    def __init__(self) -> None:
        super().__init__(
            fl.Conv2d(in_channels=3, out_channels=32, kernel_size=1),
            DownBlock(),
            MiddleBlock(),
            UpBlock(),
            fl.Conv2d(in_channels=32, out_channels=3, kernel_size=1),
        )

    def init_context(self) -> Contexts:
        """Return the initial contexts: empty ``sampling/shapes`` and ``unet/residuals`` lists."""
        contexts: Contexts = {
            "sampling": {"shapes": []},
            # Filled by DownBlock.append_residual and consumed by UpBlock.
            "unet": {"residuals": []},
        }
        return contexts


# Smoke test: instantiate the UNet and create one random input of shape (1, 3, 32, 32).
unet = UNet()
x = torch.randn(1, 3, 32, 32)

# Print the module tree, then the shape of the output for that input.
print(repr(unet))
print(unet(x).shape)

(CHAIN) UNet()
    ├── Conv2d(in_channels=3, out_channels=32, kernel_size=(1, 1)) #1
    ├── (CHAIN) DownBlock()
    │   ├── Conv2d(in_channels=32, out_channels=64, kernel_size=(3, 3), padding=(1, 1)) #1
    │   ├── Lambda(append_residual(x: torch.Tensor)) #1
    │   ├── (CHAIN) Downsample(channels=64, scale_factor=2) #1
    │   │   ├── SetContext(context=sampling, key=shapes)
    │   │   ├── Lambda(<lambda>(x))
    │   │   └── Conv2d(in_channels=64, out_channels=64, kernel_size=(3, 3), stride=(2, 2))
    │   ├── Conv2d(in_channels=64, out_channels=128, kernel_size=(3, 3), padding=(1, 1)) #2
    │   ├── Lambda(append_residual(x: torch.Tensor)) #2
    │   ├── (CHAIN) Downsample(channels=128, scale_factor=2) #2
    │   │   ├── SetContext(context=sampling, key=shapes)
    │   │   ├── Lambda(<lambda>(x))
    │   │   └── Conv2d(in_channels=128, out_channels=128, kernel_size=(3, 3), stride=(2, 2))
    │   ├── Conv2d(in_channels=128, out_channels=256, kernel_size=(3, 3), padding=(1, 1)) #3
  

In [3]:
class Resblock(fl.Sum):
    """``fl.Sum`` of two branches built from the same channel configuration.

    * first branch: 3x3 conv (in->out), SiLU, 3x3 conv (out->out)
    * second branch: a single 3x3 conv (in->out)

    Args:
        in_channels: number of channels expected by both branches.
        out_channels: number of channels produced by both branches.
    """

    def __init__(self, in_channels: int=1, out_channels: int=1) -> None:
        super().__init__(
            fl.Chain(
                fl.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, padding=1),
                fl.SiLU(),
                fl.Conv2d(in_channels=out_channels, out_channels=out_channels, kernel_size=3, padding=1),
            ),
            fl.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, padding=1),
        )

In [4]:
class Dropout(nn.Dropout, fl.Module):
    """``torch.nn.Dropout`` that also derives from ``fl.Module``.

    Args:
        probability: forwarded to ``nn.Dropout`` as ``p``.
        inplace: forwarded to ``nn.Dropout`` as ``inplace``.
    """

    def __init__(self, probability: float = 0.5, inplace: bool = False) -> None:
        super().__init__(
            p=probability,
            inplace=inplace,
        )

class MaxPool2d(nn.MaxPool2d, fl.Module):
    """``torch.nn.MaxPool2d`` that also derives from ``fl.Module``.

    Args:
        factor: forwarded to ``nn.MaxPool2d`` as its kernel size.
    """

    def __init__(self, factor: int = 2) -> None:
        super().__init__(kernel_size=factor)

In [5]:
class Encoder(fl.Chain):
    """Encoder half of the autoencoder.

    A 1x1 convolution lifts the input to 32 channels, three ``Resblock`` +
    ``MaxPool2d(2)`` stages widen it to 64, 128 and 256 channels, and a final
    1x1 convolution projects to 4 channels.

    Args:
        input_channels: number of channels of the input.
    """

    def __init__(self, input_channels: int = 3):
        super().__init__(
            fl.Conv2d(in_channels=input_channels, out_channels=32, kernel_size=1, padding=0),
            Resblock(in_channels=32, out_channels=64),
            MaxPool2d(factor=2),
            Resblock(in_channels=64, out_channels=128),
            MaxPool2d(factor=2),
            Resblock(in_channels=128, out_channels=256),
            MaxPool2d(factor=2),
            fl.Conv2d(in_channels=256, out_channels=4, kernel_size=1, padding=0),
        )

class Decoder(fl.Chain):
    """Decoder half of the autoencoder.

    A 1x1 convolution expands the 4-channel input to 256 channels, three
    ``Resblock`` + ``fl.Upsample(upsample_factor=2)`` stages narrow it to 128,
    64 and 32 channels, then a last ``Resblock`` and a 1x1 convolution produce
    ``output_channels`` channels.

    Args:
        output_channels: number of channels of the output.
    """

    def __init__(self, output_channels: int = 3):
        super().__init__(
            fl.Conv2d(in_channels=4, out_channels=256, kernel_size=1, padding=0),
            Resblock(in_channels=256, out_channels=128),
            fl.Upsample(channels=128, upsample_factor=2),
            Resblock(in_channels=128, out_channels=64),
            fl.Upsample(channels=64, upsample_factor=2),
            Resblock(in_channels=64, out_channels=32),
            fl.Upsample(channels=32, upsample_factor=2),
            Resblock(in_channels=32, out_channels=output_channels),
            fl.Conv2d(in_channels=output_channels, out_channels=output_channels, kernel_size=1, padding=0),
        )

In [6]:
class AutoEncoder(fl.Chain):
    """Autoencoder: a default ``Encoder`` followed by a default ``Decoder``."""

    def __init__(self) -> None:
        # The encoder is built before the decoder, matching the order they are applied.
        super().__init__(Encoder(), Decoder())

In [7]:
class DropoutAdapter(Adapter[fl.SiLU], fl.Chain):
    """Adapter around a ``fl.SiLU`` that appends a ``Dropout`` layer while injected.

    Args:
        target: the ``fl.SiLU`` module to adapt.
        dropout: probability passed to the ``Dropout`` layer created on injection.
    """

    def __init__(self, target: fl.SiLU, dropout: float = 0.5):
        self.dropout = dropout
        with self.setup_adapter(target):
            super().__init__(target)

    def inject(self, parent: fl.Chain | None = None):
        """Append a new ``Dropout`` to this chain, then inject the adapter into ``parent``."""
        self.append(Dropout(probability=self.dropout))
        super().inject(parent)

    def eject(self):
        """Remove the ``Dropout`` from this chain, then eject the adapter."""
        # ensure_find: same as find, but it cannot return None.
        self.remove(self.ensure_find(Dropout))
        super().eject()
        

In [8]:
# Demo on a single Resblock: build a DropoutAdapter around the SiLU of its inner chain.
resnet = Resblock(128, 128)
silu = resnet.Chain.SiLU
adapter = DropoutAdapter(silu, dropout=0.5)

# After inject(), the printed trees show the adapter (SiLU + Dropout) where the SiLU was.
adapter.inject(resnet.Chain)
print(repr(resnet))
print(repr(adapter))
# After eject(), the printed trees show the bare SiLU back in the chain and no Dropout in the adapter.
adapter.eject()
print(repr(resnet))
print(repr(adapter))

(SUM) Resblock()
    ├── (CHAIN)
    │   ├── Conv2d(in_channels=128, out_channels=128, kernel_size=(3, 3), padding=(1, 1)) #1
    │   ├── (CHAIN) DropoutAdapter()
    │   │   ├── SiLU()
    │   │   └── Dropout()
    │   └── Conv2d(in_channels=128, out_channels=128, kernel_size=(3, 3), padding=(1, 1)) #2
    └── Conv2d(in_channels=128, out_channels=128, kernel_size=(3, 3), padding=(1, 1))
(CHAIN) DropoutAdapter()
    ├── SiLU()
    └── Dropout()
(SUM) Resblock()
    ├── (CHAIN)
    │   ├── Conv2d(in_channels=128, out_channels=128, kernel_size=(3, 3), padding=(1, 1)) #1
    │   ├── SiLU()
    │   └── Conv2d(in_channels=128, out_channels=128, kernel_size=(3, 3), padding=(1, 1)) #2
    └── Conv2d(in_channels=128, out_channels=128, kernel_size=(3, 3), padding=(1, 1))
(CHAIN) DropoutAdapter()
    └── SiLU()


In [9]:
def load_dropout(chain : fl.Chain, dropout : float = 0.5):
    """Inject a ``DropoutAdapter`` around every ``fl.SiLU`` found in ``chain``.

    Args:
        chain: module tree to walk; it is modified in place.
        dropout: dropout probability given to each adapter.
    """
    # Snapshot the matches before touching the tree: inject() replaces modules
    # inside their parent, and mutating a structure while a lazy generator is
    # still traversing it is fragile.
    targets = list(chain.walk(fl.SiLU))
    for silu, parent in targets:
        DropoutAdapter(silu, dropout).inject(parent)

In [10]:
def generate_images(size, num_images, output_folder, shape="square", half_extent=20):
    """Generate random single-shape images and save them as PNG files.

    Every image gets a random solid background colour and one randomly coloured
    shape whose bounding box is ``2 * half_extent`` pixels wide, centred at a
    random position chosen so the shape stays inside the image. Files are
    written to ``output_folder`` as ``image_<i>.png`` for ``i`` in
    ``1..num_images``.

    Args:
        size: ``(width, height)`` of the images, in pixels.
        num_images: number of images to write.
        output_folder: destination directory; created if it does not exist.
        shape: ``"circle"``, ``"square"`` or ``"triangle"``, or ``"random"`` to
            pick one of the three independently for each image. Defaults to
            ``"square"``, which reproduces the previous behaviour.
        half_extent: half the side of the shape's bounding box, in pixels.

    Raises:
        ValueError: if ``shape`` is unknown, ``half_extent`` is negative, or
            ``size`` is too small to contain the shape.
    """
    shapes = ("circle", "square", "triangle")
    if shape != "random" and shape not in shapes:
        raise ValueError(f"Unknown shape {shape!r}; expected one of {shapes} or 'random'.")
    if half_extent < 0:
        raise ValueError(f"half_extent must be non-negative, got {half_extent}.")

    width, height = size
    # The centre is drawn from [half_extent, dim - half_extent]; fail early with
    # a clear message instead of an "empty range" error from random.randint.
    if width < 2 * half_extent or height < 2 * half_extent:
        raise ValueError(
            f"size {size} is too small for a shape of half extent {half_extent}; "
            f"each dimension must be at least {2 * half_extent}."
        )

    # Create the output folder if it doesn't exist
    os.makedirs(output_folder, exist_ok=True)

    for i in range(1, num_images + 1):
        # NOTE: the order of the random draws (background, position, shape
        # colour) is kept so that seeded runs give the same images as before.
        background = (random.randint(0, 255), random.randint(0, 255), random.randint(0, 255))
        img = Image.new("RGB", size, color=background)
        draw = ImageDraw.Draw(img)

        kind = random.choice(shapes) if shape == "random" else shape

        center_x = random.randint(half_extent, width - half_extent)
        center_y = random.randint(half_extent, height - half_extent)
        shape_color = (random.randint(0, 255), random.randint(0, 255), random.randint(0, 255))

        left, top = center_x - half_extent, center_y - half_extent
        right, bottom = center_x + half_extent, center_y + half_extent

        if kind == "circle":
            draw.ellipse([left, top, right, bottom], fill=shape_color)
        elif kind == "square":
            draw.rectangle([left, top, right, bottom], fill=shape_color)
        else:  # "triangle": apex at the top centre, base along the bottom edge
            draw.polygon([(center_x, top), (left, bottom), (right, bottom)], fill=shape_color)

        # Save the image to the output folder
        img.save(os.path.join(output_folder, f"image_{i}.png"))

# Write the synthetic datasets to disk: 100 training images and 20 test images, 128x128 pixels each.
generate_images((128, 128), 100, "../data/dataset_train")
generate_images((128, 128), 20, "../data/dataset_test")

In [11]:
class Dataset:
    def __init__(self, path) -> None:
        self.data = list(range(100))
        self.path = path

    def __len__(self) -> int:
        return len(self.data)

    def __str__(self) -> str:
        return f'Dataset(len={len(self)})'

    def __repr__(self) -> str:
        return str(self)
    
    def __getitem__(self, key : str|int) -> int:
        match key:
            case key if isinstance(key, str):
                raise ValueError('Dataset does not take string as index.')
            case _:
                return self.data[key]
            
class ImageDataset:
    """In-memory collection of the images found directly inside a directory.

    Files whose name ends with ``.jpg``, ``.png`` or ``.jpeg`` (case-insensitive)
    are loaded eagerly, in sorted filename order. Files that fail to load are
    reported on stdout and skipped, so ``data`` only ever contains images and
    ``image_files[i]`` is the filename of ``data[i]``.

    Attributes are documented inline in ``__init__``.
    """

    def __init__(self, path) -> None:
        # Directory the images are read from.
        self.path = path
        # os.listdir returns entries in arbitrary order: sort so that the
        # dataset order is reproducible from one run/machine to the next.
        candidates = sorted(
            f for f in os.listdir(path) if f.lower().endswith(('.jpg', '.png', '.jpeg'))
        )
        loaded = [(file, self.load_image(file)) for file in candidates]
        # Keep only successful loads; a None entry would crash much later,
        # far from the file that caused it.
        self.image_files = [file for file, image in loaded if image is not None]
        self.data = [image for _, image in loaded if image is not None]

    def __len__(self) -> int:
        """Number of successfully loaded images."""
        return len(self.data)

    def __str__(self) -> str:
        return f'ImageDataset(len={len(self)})'

    def __repr__(self) -> str:
        return str(self)

    def __getitem__(self, key: int) -> Image.Image:
        """Return the image at position ``key``."""
        return self.data[key]

    def load_image(self, file: str) -> Image.Image | None:
        """Load ``file`` (a name relative to ``self.path``) fully into memory.

        Returns:
            The decoded image, or ``None`` if it could not be opened or
            decoded (the error is printed).
        """
        image_path = os.path.join(self.path, file)
        try:
            with Image.open(image_path) as image:
                # Image.open is lazy. Decode the pixels now so that corrupt
                # files are caught here, and so the file handle can be released
                # when the `with` block exits instead of staying open.
                image.load()
                return image
        except Exception as e:
            print(f"Error loading image '{file}': {e}")
            return None

In [80]:
# Load the datasets (lists of PIL images) and build the autoencoder with
# dropout injected after every SiLU.
path_dataset_train = "../data/dataset_train/"
path_dataset_test = "../data/dataset_test/"
dataset_train = ImageDataset(path_dataset_train).data
dataset_test = ImageDataset(path_dataset_test).data
autoencoder = AutoEncoder()
load_dropout(autoencoder, dropout=0.5)
autoencoder.to(device)
autoencoder.train()
lr = 1e-4
num_epochs = 100

optimizer = torch.optim.Adam(autoencoder.parameters() , lr=lr)
for epoch in range(num_epochs):
    # One optimizer step per epoch on the mean reconstruction loss over the
    # whole training set. Back-propagating each image's share right away
    # accumulates the same gradient as averaging first, but lets autograd free
    # each image's graph instead of keeping all of them alive until the end of
    # the epoch.
    loss_iter = torch.zeros((), device=device)
    for image in dataset_train:
        image = utils.image_to_tensor(image).to(device)
        y = autoencoder(image)
        loss = (y-image).norm()
        (loss / len(dataset_train)).backward()
        loss_iter += loss.detach()
    optimizer.step()
    optimizer.zero_grad()
    epoch_loss = (loss_iter / len(dataset_train)).item()
    print(f"epoch {epoch} : loss {epoch_loss}")

# Evaluation: switch off dropout and autograd, then show each test image next
# to its reconstruction.
autoencoder.eval()
with torch.no_grad():
    for image in dataset_test:
        image = utils.image_to_tensor(image).to(device)
        result = autoencoder(image)
        utils.tensor_to_image(image).show()
        utils.tensor_to_image(result).show()

NameError: name 'a' is not defined

In [13]:
lr = 1e-4
num_epochs = 1000

# Initialize wandb
wandb.init(project='autoencoder', entity = "finegrain-cs", name = '1000epoch, refiners, squares', config={
    'lr': lr,
    'num_epochs': num_epochs
})

# Load the datasets (lists of PIL images) and build the autoencoder with
# dropout injected after every SiLU.
path_dataset_train = "../data/dataset_train/"
path_dataset_test = "../data/dataset_test/"
dataset_train = ImageDataset(path_dataset_train).data
dataset_test = ImageDataset(path_dataset_test).data
autoencoder = AutoEncoder()
load_dropout(autoencoder, dropout=0.5)
autoencoder.to(device)
autoencoder.train()

optimizer = torch.optim.Adam(autoencoder.parameters() , lr=lr)

for epoch in range(num_epochs):
    # One optimizer step per epoch on the mean reconstruction loss over the
    # whole training set. Back-propagating each image's share right away
    # accumulates the same gradient as averaging first, but lets autograd free
    # each image's graph instead of keeping all of them alive until the end of
    # the epoch.
    loss_iter = torch.zeros((), device=device)
    for images in dataset_train:
        images = utils.image_to_tensor(images).to(device)
        y = autoencoder(images)
        loss = (y - images).norm()
        (loss / len(dataset_train)).backward()
        loss_iter += loss.detach()
    optimizer.step()
    optimizer.zero_grad()

    train_loss = (loss_iter / len(dataset_train)).item()

    # Log the loss to wandb
    wandb.log({'step': epoch, 'loss': train_loss})

    print(f"step {epoch} : loss {train_loss}")

# Testing at the end of training.
# Switch to eval mode first: otherwise the injected Dropout layers stay active
# and the test loss / reconstructions are computed on randomly masked activations.
autoencoder.eval()
loss_iter_test = 0
reconstructed_images = []

with torch.no_grad():
    for images_test in dataset_test:
        images_test = utils.image_to_tensor(images_test).to(device)
        y_test = autoencoder(images_test)
        loss_test = (y_test - images_test).norm()
        loss_iter_test += loss_test

        # Side-by-side picture: original on the left, reconstruction on the right.
        concat = Image.new('RGB', (256, 128))
        concat.paste(utils.tensor_to_image(images_test), (0, 0))
        concat.paste(utils.tensor_to_image(y_test), (128, 0))

        reconstructed_images.append(concat)

# The side-by-side pictures are already PIL images, so they are logged directly.
wandb.log({"reconstructed_images": [wandb.Image(image) for image in reconstructed_images]})

loss_test = loss_iter_test / len(dataset_test)

# Log the test loss to wandb
wandb.log({'test_loss': loss_test.item()})

print(f"Test Loss at the end of training: {loss_test.item()}")




0,1
loss,█▆▅▄▄▃▃▃▃▃▃▃▃▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▁▂▁▁▁▁▁▁▁▁▁
step,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇███

0,1
loss,27.00144
step,188.0


step 0 : loss 145.10202026367188
step 1 : loss 143.31204223632812
step 2 : loss 141.11109924316406
step 3 : loss 137.99388122558594
step 4 : loss 134.60145568847656
step 5 : loss 133.7023162841797
step 6 : loss 130.5761260986328
step 7 : loss 126.23130798339844
step 8 : loss 122.20919799804688
step 9 : loss 116.1917724609375
step 10 : loss 105.69256591796875
step 11 : loss 91.06919860839844
step 12 : loss 81.76445770263672
step 13 : loss 86.25621795654297
step 14 : loss 76.3280258178711
step 15 : loss 72.77108764648438
step 16 : loss 71.07861328125
step 17 : loss 73.41088104248047
step 18 : loss 73.78038787841797
step 19 : loss 70.8030776977539
step 20 : loss 68.08450317382812
step 21 : loss 66.84661102294922
step 22 : loss 66.78334045410156
step 23 : loss 68.0683364868164
step 24 : loss 65.31645965576172
step 25 : loss 63.663856506347656
step 26 : loss 63.80863952636719
step 27 : loss 63.99565887451172
step 28 : loss 63.79499816894531
step 29 : loss 62.354896545410156
step 30 : loss 6