In [None]:
import torch
import numpy as np
from dataset import *
from utils import *

# Number of samples per mini-batch used by both loaders below.
batch_size = 2
# Training and validation sets, each built from a colour-image directory and a label
# directory. `dataset` and `target_remap` are presumably supplied by the star imports
# above (dataset / utils) — TODO confirm.
training_data = dataset("../atrain/color", "../atrain/label", target_transform=target_remap())
validation_data = dataset("../val/color", "../val/label", target_transform=target_remap())
#test_data = dataset("rtest/color", "rtest/label", target_transform=target_remap())

# The training loader uses the default collate function while the validation loader
# collates with `diff_size_collate` — presumably because validation images vary in
# size (TODO confirm). Note that the validation loader is shuffled as well.
train_dataloader = DataLoader(training_data, batch_size=batch_size, shuffle=True, pin_memory=True)
val_dataloader = DataLoader(validation_data, batch_size=batch_size, shuffle=True, pin_memory=True, collate_fn=diff_size_collate)
#test_dataloader = DataLoader(test_data, batch_size=batch_size, shuffle=True, pin_memory=True,collate_fn=diff_size_collate)

### Old Autoencoder (not transformer)

In [None]:
import torch.nn as nn

class EncoderPart(nn.Module):
    """One down-sampling stage: 3x3 convolution, ReLU, then 2x2 max-pooling.

    Args:
        din: number of input channels.
        dout: number of output channels.

    The padding of 1 keeps the spatial size through the convolution; the
    pooling step (kernel 2, stride 2) then halves height and width.
    """

    def __init__(self, din, dout):
        super().__init__()
        conv = nn.Conv2d(din, dout, kernel_size=3, padding=1)
        activation = nn.ReLU()
        pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.encoder = nn.Sequential(conv, activation, pool)

    def forward(self, x):
        """Apply the stage to a (B, din, H, W) tensor."""
        features = self.encoder(x)
        return features

class Encoder(nn.Module):
    """Three chained `EncoderPart` stages mapping `din` channels to 64, 32, then 16.

    Args:
        din: number of channels of the input image.
    """

    def __init__(self, din):
        super().__init__()
        self.encoderPart1 = EncoderPart(din, 64)
        self.encoderPart2 = EncoderPart(64, 32)
        self.encoderPart3 = EncoderPart(32, 16)

    def forward(self, x):
        """Run the input through the three stages in order."""
        for stage in (self.encoderPart1, self.encoderPart2, self.encoderPart3):
            x = stage(x)
        return x

class DecoderPart(nn.Module):
    """One up-sampling stage: 3x3 convolution, ReLU, then 2x upsampling.

    Args:
        din: number of input channels.
        dout: number of output channels.

    The sub-module is stored under the attribute name `encoder`; the name is
    kept as-is because it determines the parameter names in the state dict.
    """

    def __init__(self, din, dout):
        super().__init__()
        conv = nn.Conv2d(din, dout, kernel_size=3, padding=1)
        activation = nn.ReLU()
        upsample = nn.Upsample(scale_factor=2)
        self.encoder = nn.Sequential(conv, activation, upsample)

    def forward(self, x):
        """Apply the stage to a (B, din, H, W) tensor; height and width double."""
        upsampled = self.encoder(x)
        return upsampled

class Decoder(nn.Module):
    """Three `DecoderPart` stages (16 -> 16 -> 32 -> 64 channels) plus an output head.

    Args:
        dout: number of channels of the reconstructed output.

    The head is a 3x3 convolution down to `dout` channels followed by a
    sigmoid, so output values lie in [0, 1].
    """

    def __init__(self, dout):
        super().__init__()
        self.decoderPart1 = DecoderPart(16, 16)
        self.decoderPart2 = DecoderPart(16, 32)
        self.decoderPart3 = DecoderPart(32, 64)
        head_layers = [
            nn.Conv2d(64, dout, kernel_size=3, padding=1),
            nn.Sigmoid(),
        ]
        self.decoderOut = nn.Sequential(*head_layers)

    def forward(self, x):
        """Decode a 16-channel feature map into a `dout`-channel output."""
        pipeline = (self.decoderPart1, self.decoderPart2, self.decoderPart3, self.decoderOut)
        for layer in pipeline:
            x = layer(x)
        return x

class Autoencoder(nn.Module):
    """Convolutional autoencoder: an `Encoder` followed by a `Decoder`.

    Args:
        din: number of input channels handed to the encoder.
        dout: number of output channels produced by the decoder.
    """

    def __init__(self, din, dout):
        super().__init__()
        self.encoder = Encoder(din)
        self.decoder = Decoder(dout)

    def forward(self, x):
        """Encode the input and decode the resulting latent feature map."""
        return self.decoder(self.encoder(x))
        

In [None]:
"""
model = Autoencoder(3, 3)
image, label = training_data[0]
print(image)
print(image.size())
pred = model(image)
print(pred)
"""

### Transformer Autoencoder

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

class PatchEmbedding(nn.Module):
    """Split an image into non-overlapping square patches and embed each one.

    Args:
        img_size: nominal side length of the (square) input image; only used
            to pre-compute `num_patches`.
        patch_size: side length of one square patch.
        in_channels: number of channels of the input image.
        embed_dim: size of the embedding produced for every patch.

    Attributes:
        num_patches: patch count for an `img_size` x `img_size` image.
        patch_size: side length of one patch.
        patch_dim: number of values in one flattened patch
            (`in_channels * patch_size * patch_size`).
        projection: linear map from `patch_dim` to `embed_dim`.
    """
    def __init__(self, img_size=512, patch_size=16, in_channels=3, embed_dim=512):
        super().__init__()
        self.num_patches = (img_size // patch_size) ** 2  # Number of patches
        self.patch_size = patch_size
        self.patch_dim = in_channels * patch_size * patch_size  # Flattened patch size
        self.projection = nn.Linear(self.patch_dim, embed_dim)  # Linear projection to embedding size

    def forward(self, x):
        """Embed a (B, C, H, W) batch and return a (B, num_patches, embed_dim) tensor.

        The number of patches is derived from the actual H and W of `x`, so
        inputs whose size differs from `img_size` (including non-square ones)
        are patchified correctly. H and W must be multiples of `patch_size`;
        otherwise the reshape below raises, as it always did.
        """
        B, C, H, W = x.shape  # Batch, Channels, Height, Width
        p = self.patch_size
        rows, cols = H // p, W // p
        # (B, C, H, W) -> (B, rows, cols, p, p, C): one (p, p, C) block per patch
        x = x.reshape(B, C, rows, p, cols, p).permute(0, 2, 4, 3, 5, 1)
        # Flatten every patch: (B, rows * cols, p * p * C). Using the real
        # grid size instead of self.num_patches keeps the last axis equal to
        # the true patch size for any input resolution.
        x = x.reshape(B, rows * cols, -1)
        return self.projection(x)  # Project into embedding space



# Uses Transformer encoder layers (self-attention + feed-forward network).
# A learned positional embedding is added so patch order information is preserved.
class TransformerEncoder(nn.Module):
    """Self-attention encoder over a sequence of patch embeddings.

    Args:
        img_size: side length of the (square) input image.
        patch_size: side length of one patch; together with `img_size` it
            fixes how many positional embeddings are allocated.
        embed_dim: embedding size of every patch token.
        num_heads: number of attention heads.
        ff_dim: hidden size of the feed-forward sub-layer.
        num_layers: number of stacked encoder layers.
    """
    def __init__(self, img_size=512, patch_size=16, embed_dim=512, num_heads=4, ff_dim=1024, num_layers=6):
        super().__init__()
        self.encoder_layers = nn.TransformerEncoder(
            # batch_first=True: tokens arrive as (B, P, E). With the default
            # (sequence-first) layout the batch axis would be treated as the
            # sequence, i.e. attention would mix different images instead of
            # the patches of one image.
            nn.TransformerEncoderLayer(d_model=embed_dim, nhead=num_heads, dim_feedforward=ff_dim, batch_first=True),
            num_layers=num_layers
        )
        num_patches = (img_size // patch_size) * (img_size // patch_size)
        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches, embed_dim))  # Positional encoding

    def forward(self, x):
        """Encode a (B, P, embed_dim) token tensor; the output has the same shape."""
        x = x + self.pos_embedding[:, :x.shape[1], :]    # Add position information
        return self.encoder_layers(x)



# Uses cross-attention to refine the latent representation.
"""
class TransformerDecoder(nn.Module):    
    def __init__(self, embed_dim=512, num_heads=4, ff_dim=1024, num_layers=6):
        super().__init__()
        self.decoder_layers = nn.TransformerDecoder(
            nn.TransformerDecoderLayer(d_model=embed_dim, nhead=num_heads, dim_feedforward=ff_dim),
            num_layers=num_layers
        )
    
    def forward(self, x, memory):
        return self.decoder_layers(x, memory)  # Keep the output in (B, P, embed_dim)
"""


class CNNDecoder(nn.Module):
    """Convolutional decoder that turns patch embeddings back into an image.

    Args:
        img_size: side length of the (square) image being reconstructed.
        embed_dim: embedding size of every incoming patch token.
        num_patches: number of patch tokens expected per sample.
        patch_size: side length of one patch; determines how many 2x
            upsampling stages are built (`int(log2(patch_size))`).
        out_channels: number of channels of the reconstructed image.

    Every upsampling stage doubles the resolution with a stride-2 transposed
    convolution and halves the channel count, then refines with a 3x3
    convolution (each followed by batch-norm and ReLU). A final 3x3
    convolution plus sigmoid maps to `out_channels` values in [0, 1].
    """

    def __init__(self, img_size=512, embed_dim=512, num_patches=1024, patch_size=16, out_channels=3):
        super().__init__()
        self.num_patches = num_patches
        # Side length of the square patch grid the tokens are laid out on
        self.patches_per_side = img_size // patch_size
        self.embed_dim = embed_dim

        stages = []
        width = embed_dim
        # One stage per factor of two between the patch grid and the full image
        for _ in range(int(math.log2(patch_size))):
            narrower = width // 2
            stages += [
                # learnable 2x upsampling that also halves the channels
                nn.ConvTranspose2d(width, narrower, kernel_size=2, stride=2),
                nn.BatchNorm2d(narrower),
                nn.ReLU(),
                # refinement at the new resolution
                nn.Conv2d(narrower, narrower, kernel_size=3, stride=1, padding=1),
                nn.BatchNorm2d(narrower),
                nn.ReLU(),
            ]
            width = narrower

        # Output head: project to the image channels and bound values with a sigmoid
        stages += [
            nn.Conv2d(width, out_channels, kernel_size=3, stride=1, padding=1),
            nn.Sigmoid(),
        ]
        self.decoder = nn.Sequential(*stages)

    def forward(self, x):
        """Decode a (B, num_patches, embed_dim) tensor into an image batch.

        Raises:
            ValueError: if the patch count or embedding size of `x` differs
                from the values given at construction time.
        """
        batch, n_patches, n_features = x.shape
        if n_patches != self.num_patches:
             raise ValueError(f"Input patch count {n_patches} doesn't match expected {self.num_patches}")
        if n_features != self.embed_dim:
             raise ValueError(f"Input embed dim {n_features} doesn't match expected {self.embed_dim}")

        side = self.patches_per_side
        # (B, P, E) -> (B, side, side, E) -> channels-first (B, E, side, side)
        grid = x.view(batch, side, side, n_features).permute(0, 3, 1, 2).contiguous()
        return self.decoder(grid)


class ReconstructionTransformerAutoencoder(nn.Module):
    """Image autoencoder: patch embedding -> transformer encoder -> CNN decoder.

    Args:
        img_size: side length of the (square) input image.
        patch_size: side length of one patch.
        in_channels: channels of the input image; the reconstruction has the
            same number of channels.
        embed_dim: embedding size of every patch token.
        num_heads: attention heads in the transformer encoder.
        ff_dim: feed-forward width in the transformer encoder
            (the default is 2 * embed_dim).
        num_layers: number of transformer encoder layers.
    """

    def __init__(self, img_size=512, patch_size=16, in_channels=3, embed_dim=512, num_heads=8, ff_dim=1024, num_layers=6):
        super().__init__()
        embedder = PatchEmbedding(img_size, patch_size, in_channels, embed_dim)
        self.patch_embed = embedder
        self.encoder = TransformerEncoder(img_size, embedder.patch_size, embed_dim, num_heads, ff_dim, num_layers)
        # The decoder is configured from the embedder so the patch geometry always agrees
        self.decoder = CNNDecoder(
            img_size=img_size,
            embed_dim=embed_dim,
            num_patches=embedder.num_patches,
            patch_size=embedder.patch_size,
            out_channels=in_channels,
        )

    def forward(self, x):
        """Return the decoder's reconstruction of the image batch `x`."""
        tokens = self.patch_embed(x)
        latent = self.encoder(tokens)
        return self.decoder(latent)





In [None]:
import torch
import torch.nn as nn

class _SegmentationTokenDecoder(nn.Module):
    """Stack of transformer decoder layers working on batch-first (B, P, E) tokens.

    Private helper of `SegmentationTransformerAutoencoder`. The layer stack is
    stored as `decoder_layers`, so parameter names in the state dict read
    `decoder.decoder_layers.*`.
    """
    def __init__(self, embed_dim=512, num_heads=4, ff_dim=1024, num_layers=6):
        super().__init__()
        self.decoder_layers = nn.TransformerDecoder(
            nn.TransformerDecoderLayer(d_model=embed_dim, nhead=num_heads, dim_feedforward=ff_dim, batch_first=True),
            num_layers=num_layers
        )

    def forward(self, x, memory):
        """Decode target tokens `x` while cross-attending to `memory`; both are (B, P, E)."""
        return self.decoder_layers(x, memory)


class SegmentationTransformerAutoencoder(nn.Module):
    """Transformer encoder/decoder that predicts a per-pixel class map.

    Args:
        img_size: nominal side length of the (square) input image.
        patch_size: side length of one patch.
        in_channels: number of channels of the input image.
        embed_dim: embedding size of every patch token.
        num_heads: attention heads in encoder and decoder.
        ff_dim: feed-forward width in encoder and decoder.
        num_layers: number of layers in both encoder and decoder.
        num_classes: number of segmentation classes (output channels).

    Relies on `PatchEmbedding` and `TransformerEncoder` being defined in the
    notebook namespace. The token decoder is built from the private helper
    above, so the class no longer depends on an external `TransformerDecoder`.
    """
    def __init__(self, img_size=512, patch_size=16, in_channels=3, embed_dim=512, num_heads=8, ff_dim=1024, num_layers=12, num_classes=4):
        super().__init__()

        # Patch embedding to break image into patches
        self.patch_embed = PatchEmbedding(img_size, patch_size, in_channels, embed_dim)

        # Transformer encoder for feature extraction
        self.encoder = TransformerEncoder(img_size, self.patch_embed.patch_size, embed_dim, num_heads, ff_dim, num_layers)

        # Transformer decoder refining the encoded tokens
        self.decoder = _SegmentationTokenDecoder(embed_dim, num_heads, ff_dim, num_layers)

        # Segmentation head: 1x1 convolution from embed_dim to num_classes logits
        self.segmentation_head = nn.Conv2d(embed_dim, num_classes, kernel_size=1)

    def forward(self, x):
        """Return class logits of shape (B, num_classes, H, W) for a (B, C, H, W) input."""
        patches = self.patch_embed(x)
        encoded = self.encoder(patches)
        decoded = self.decoder(encoded, encoded)  # (B, num_patches, embed_dim)

        B, P, D = decoded.shape
        patch = self.patch_embed.patch_size
        grid_h, grid_w = x.shape[-2] // patch, x.shape[-1] // patch
        # Move the feature axis in front of the token axis BEFORE unfolding the
        # tokens into a 2-D grid. A plain view(B, D, h, w) on a (B, P, D)
        # tensor would interleave tokens and features.
        feature_map = decoded.transpose(1, 2).reshape(B, D, grid_h, grid_w)

        segmentation_logits = self.segmentation_head(feature_map)  # (B, num_classes, grid_h, grid_w)

        # Upsample the patch-resolution logits back to the resolution of the input
        return nn.functional.interpolate(
            segmentation_logits, size=tuple(x.shape[-2:]), mode='bilinear', align_corners=False
        )


In [None]:
"""from tqdm import tqdm
import torch
import imgaug.augmenters as iaa

if torch.backends.mps.is_available():
    device = torch.device("mps")
elif torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
    
    
model_1 = ReconstructionTransformerAutoencoder().to(device)
pytorch_total_params = sum(p.numel() for p in model_1.parameters())
print(f"Total parameters in model_1 (SegmentationTransformerAutoencoder): {pytorch_total_params}")
model_2 = SegmentationTransformerAutoencoder().to(device)
loss_fn = nn.CrossEntropyLoss() # or MSE for reconstruction
learning_rate = 1e-3


losses = []
target_batch_size = 1  #TODO before submission
batch_size = 1          #TODO before submission
to_print = True
for model in [model_2, model_1]:
    model.train()
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    for batch, (X, y) in enumerate(tqdm(train_dataloader, total=len(train_dataloader), desc="Training")):
        
        X = X.to(device)
        y = y.to(device)
        # Compute prediction
        pred = model(X)
        print(f"pred shape: {pred.shape}")
        print(f"y shape: {y.shape}")    #comment for reconstruction
        print(f"X shape: {X.shape}")
        # Compute loss
        #loss = loss_fn(pred, X) # for reconstruction
        loss = loss_fn(pred, y.squeeze(1)) #for segmentation
        losses.append(loss.item())
        
        loss.backward()
        
        # Ensure gradients are reset to 0 for new batch
        optimizer.step()
        optimizer.zero_grad()
           
        plt.imshow(X[0].permute(1,2,0).cpu().detach().numpy())
        plt.show()
        plt.imshow(pred[0].permute(1,2,0).cpu().detach().numpy())
        plt.show()
    print(f"losses: {losses}")


    

"""

### Training

In [None]:
def trainReconstruction(dataloader, model, loss_fn, optimizer, target_batch_size=64):
    """Train `model` for one epoch on the reconstruction task (target = input).

    Gradients are accumulated over consecutive mini-batches until at least
    `target_batch_size` samples have been seen, then a single optimizer step
    is taken. Gradients still pending after the last mini-batch are applied as
    well, so nothing leaks into the next epoch.

    Args:
        dataloader: sized iterable yielding `(X, _)` pairs; `X` is a batched
            image tensor and the second element is ignored.
        model: network mapping `X` to a reconstruction of `X`.
        loss_fn: callable `loss_fn(pred, X)` returning a scalar tensor
            (assumed to average over the samples of a mini-batch).
        optimizer: optimizer holding the parameters of `model`.
        target_batch_size: number of samples per optimizer step.

    Returns:
        Mean of the per-mini-batch losses (NaN if the loader is empty).
    """
    losses = []
    model.train()
    # Run on whatever device the model lives on instead of relying on a notebook global
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cpu")

    optimizer.zero_grad()  # discard gradients left over from any earlier call
    pending_samples = 0
    for X, _ in tqdm(dataloader, total=len(dataloader), desc="Training"):
        X = X.to(device)
        # Compute prediction
        pred = model(X)

        # Compute loss (logged unscaled)
        loss = loss_fn(pred, X)
        losses.append(loss.item())

        # Weight the mini-batch by its share of the target batch so the
        # accumulated gradient matches the mean over target_batch_size samples.
        micro_batch = X.shape[0]
        (loss * (micro_batch / target_batch_size)).backward()
        pending_samples += micro_batch

        if pending_samples >= target_batch_size:
            optimizer.step()
            optimizer.zero_grad()
            pending_samples = 0

    # Apply whatever is left from an incomplete accumulation window
    if pending_samples > 0:
        optimizer.step()
        optimizer.zero_grad()

    return np.mean(losses) if losses else float("nan")
        
    

In [None]:
def trainSegmentation(dataloader, model, loss_fn, optimizer, target_batch_size=64):
    """Train `model` for one epoch on the segmentation task.

    Gradients are accumulated over consecutive mini-batches until at least
    `target_batch_size` samples have been seen, then a single optimizer step
    is taken. The window is sized from the samples actually delivered by the
    loader rather than from a hard-coded batch size, and gradients still
    pending after the last mini-batch are applied as well.

    Args:
        dataloader: sized iterable yielding `(X, y)` pairs of batched tensors.
        model: network mapping `X` to per-class logits.
        loss_fn: callable `loss_fn(pred, y.squeeze(1))` returning a scalar
            tensor (assumed to average over the samples of a mini-batch).
        optimizer: optimizer holding the parameters of `model`.
        target_batch_size: number of samples per optimizer step.

    Returns:
        Mean of the per-mini-batch losses (NaN if the loader is empty).
    """
    losses = []
    model.train()
    # Run on whatever device the model lives on instead of relying on a notebook global
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cpu")

    optimizer.zero_grad()  # discard gradients left over from any earlier call
    pending_samples = 0
    for X, y in tqdm(dataloader, total=len(dataloader), desc="Training"):
        X = X.to(device)
        y = y.to(device)
        # Compute prediction
        pred = model(X)

        # Compute loss (logged unscaled); squeeze(1) drops a singleton channel axis of the labels
        loss = loss_fn(pred, y.squeeze(1))
        losses.append(loss.item())

        # Weight the mini-batch by its share of the target batch so the
        # accumulated gradient matches the mean over target_batch_size samples.
        micro_batch = X.shape[0]
        (loss * (micro_batch / target_batch_size)).backward()
        pending_samples += micro_batch

        if pending_samples >= target_batch_size:
            optimizer.step()
            optimizer.zero_grad()
            pending_samples = 0

    # Apply whatever is left from an incomplete accumulation window
    if pending_samples > 0:
        optimizer.step()
        optimizer.zero_grad()

    return np.mean(losses) if losses else float("nan")
        
    

### Evaluation

In [None]:
# Side length handed to `process_batch_forward` and interpolation mode handed
# to `process_batch_reverse` during evaluation.
target_size = 512
interpolation = 'bilinear'

def evalReconstruction(dataloader, model, loss_fn):
    """Evaluate reconstruction quality, scoring every image against its original.

    Each batch is passed through `process_batch_forward` before the model and
    the predictions through `process_batch_reverse` afterwards (presumably a
    resize to `target_size` and back to each image's own size — TODO confirm
    against utils). The loss is then computed per image.

    Args:
        dataloader: sized iterable yielding `(images, _)`; `images` is
            iterated sample by sample, the second element is ignored.
        model: reconstruction network.
        loss_fn: callable `loss_fn(pred, target)` returning a scalar tensor.

    Returns:
        `(mean_loss, debug_mean_loss)`: both are the mean per-image loss; the
        first is a Python float, the second is computed with `np.mean` from
        the collected loss list. Both are NaN when no image was evaluated.
    """
    model.eval()
    # Run on whatever device the model lives on instead of relying on a notebook global
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cpu")

    total_loss = 0.0
    losses = []
    with torch.no_grad():
        for original_X, _ in tqdm(dataloader, total=len(dataloader), desc="Evaluation"):
            resized_X, meta_list = process_batch_forward(original_X, target_size=target_size)   # resize X for network
            resized_X = resized_X.to(device)

            # Drop an alpha channel before the forward pass, as evalSegmentation does
            if resized_X.ndim == 4 and resized_X.shape[1] == 4:
                resized_X = resized_X[:, :3, :, :]

            # Compute prediction
            pred = model(resized_X)

            pred = process_batch_reverse(pred, meta_list, interpolation=interpolation)

            for p, label in zip(pred, original_X):
                # Move individual prediction and original image to the device
                p = p.to(device).unsqueeze(0)  # Add batch dimension
                label = label.to(device).unsqueeze(0)  # Add batch dimension

                if label.shape[1] == 4 and label.ndim == 4:
                    print(f"    Converting original image from RGBA to RGB")
                    label = label[:, :3, :, :] # Keep only the first 3 channels (R, G, B)

                # Loss for this image; squeeze(1) only acts on single-channel originals
                loss = loss_fn(p, label.squeeze(1))
                total_loss += loss.item()
                losses.append(loss.item())

    if not losses:
        return float("nan"), float("nan")
    # Losses are collected per image, so the average must be taken over the
    # number of images; dividing by the number of batches would inflate it.
    return total_loss / len(losses), np.mean(losses)

In [None]:
# Side length handed to `process_batch_forward` and interpolation mode handed
# to `process_batch_reverse` during evaluation.
target_size = 512
interpolation = 'bilinear'

def evalSegmentation(dataloader, model, loss_fn):
    """Evaluate segmentation loss, scoring every prediction against its own label.

    Each batch is passed through `process_batch_forward` before the model and
    the predictions through `process_batch_reverse` afterwards (presumably a
    resize to `target_size` and back to each image's own size — TODO confirm
    against utils). The loss is then computed per image.

    Args:
        dataloader: sized iterable yielding `(X, y)`; `y` is iterated sample
            by sample and may be a tensor or a list of tensors.
        model: segmentation network returning per-class logits.
        loss_fn: callable `loss_fn(pred, label)` returning a scalar tensor.

    Returns:
        Mean per-image loss as a Python float (NaN when no image was evaluated).
    """
    model.eval()
    # Run on whatever device the model lives on instead of relying on a notebook global
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cpu")

    total_loss = 0.0
    num_samples = 0
    with torch.no_grad():
        for X, y in tqdm(dataloader, total=len(dataloader), desc="Evaluation"):
            X, meta_list = process_batch_forward(X, target_size=target_size)
            X = X.to(device)
            # Labels are moved to the device one by one below, which also
            # works when `y` is a list of differently sized tensors.

            if X.shape[1] == 4 and X.ndim == 4:
                X = X[:, :3, :, :] # Keep only the first 3 channels (R, G, B)

            # Compute prediction
            pred = model(X)

            pred = process_batch_reverse(pred, meta_list, interpolation=interpolation)

            for p, label in zip(pred, y):
                # Move individual prediction and label to the device
                p = p.to(device).unsqueeze(0)  # Add batch dimension
                label = label.to(device).unsqueeze(0)  # Add batch dimension

                # Loss for this image; squeeze(1) drops a singleton channel axis of the label
                loss = loss_fn(p, label.squeeze(1))
                total_loss += loss.item()
                num_samples += 1

    if num_samples == 0:
        return float("nan")
    # Losses are accumulated per image, so average over images rather than batches
    return total_loss / num_samples

### Run Reconstruction Transformer Autoencoder

In [None]:
from tqdm import tqdm
import torch

# Pick the compute device: Apple MPS first, then CUDA, otherwise CPU.
if torch.backends.mps.is_available():
    device = torch.device("mps")
elif torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")


model = ReconstructionTransformerAutoencoder(img_size=512, patch_size=16, in_channels=3, embed_dim=512, num_heads=8, ff_dim=1024, num_layers=6).to(device)
loss_fn = nn.MSELoss()
learning_rate = 1e-3
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
MODEL_SAVE_DIR = "reconstruction/autoencoder"
MODEL_NAME = "checkpoint.pytorch"
IMAGE_SAVE_DIR = "reconstruction/images"
start_epoch = 0
getPastCheckpoint = False
# Initialised BEFORE the checkpoint is read, so a value restored from the
# checkpoint is not overwritten afterwards.
best_val_loss = np.inf

# Make sure the output locations exist before an epoch of training is spent
os.makedirs(MODEL_SAVE_DIR, exist_ok=True)
os.makedirs(IMAGE_SAVE_DIR, exist_ok=True)

if getPastCheckpoint and os.path.isfile(f"{MODEL_SAVE_DIR}/{MODEL_NAME}"):
    print(f"Loading checkpoint from: {MODEL_SAVE_DIR}/{MODEL_NAME}")
    # Load the checkpoint dictionary; move tensors to the correct device.
    # NOTE: torch.load unpickles the file — only load checkpoints you created yourself.
    checkpoint = torch.load(f"{MODEL_SAVE_DIR}/{MODEL_NAME}", map_location=device)

    # Load model state
    model.load_state_dict(checkpoint["model_state_dict"])
    print(" -> Model state loaded.")

    # Load optimizer state (best effort: training can continue with a fresh optimizer)
    try:
        optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
        print(" -> Optimizer state loaded.")
    except Exception as e:
        print(f" -> Warning: Could not load optimizer state: {e}. Optimizer will start from scratch.")

    # Load scheduler state
    # try:
    #     scheduler.load_state_dict(checkpoint["scheduler_state_dict"])
    #     print(" -> Scheduler state loaded.")
    # except Exception as e:
    #     print(f" -> Warning: Could not load scheduler state: {e}. Scheduler will start from scratch.")


    # Load training metadata
    start_epoch = checkpoint.get("epoch", 0) # Load last completed epoch, training continues from next one
    # best_dev_dice = checkpoint.get("best_dev_dice", -np.inf)
    # best_dev_miou = checkpoint.get("best_dev_miou", -np.inf)
    best_val_loss = checkpoint.get("best_val_loss", np.inf)

    print(f" -> Resuming training from epoch {start_epoch + 1}")
    # print(f" -> Loaded best metrics: Dice={best_dev_dice:.6f}, mIoU={best_dev_miou:.6f}, Loss={best_dev_loss:.6f}")
    loaded_notes = checkpoint.get("notes", "N/A")
    print(f" -> Notes from checkpoint: {loaded_notes}")

elif getPastCheckpoint:
    print(f"Checkpoint file not found at {MODEL_SAVE_DIR}/{MODEL_NAME}. Starting training from scratch.")
else:
    print("Checkpoint loading disabled (getPastCheckpoint=False). Starting training from scratch.")



EPOCHS = 100
# Index of the last completed epoch; keeps `t` defined for the final plot even
# when the loop below has nothing left to run (start_epoch >= EPOCHS).
t = start_epoch - 1
print("\nStarting Training (Transformer Autoencoder)...")
for t in range(start_epoch, EPOCHS):
    current_epoch = t + 1
    print(f"Epoch {t+1}\n-------------------------------")
    train_loss = trainReconstruction(train_dataloader, model, loss_fn, optimizer)

    val_loss, debug_val_loss = evalReconstruction(val_dataloader, model, loss_fn)

    # Save model based on validation loss improvement
    if val_loss < best_val_loss:
        print(f"Validation loss improved ({best_val_loss:.6f} → {val_loss:.6f}). Saving model...")
        # best_dev_dice = val_dice_micro
        # best_dev_miou = val_miou # Save corresponding mIoU
        best_val_loss = val_loss # Save corresponding loss
        # print(f"Validation Micro Dice score improved ({best_dev_dice:.6f}). Saving model...")
        checkpoint_path = os.path.join(MODEL_SAVE_DIR, f"{MODEL_NAME}") # Changed name
        checkpoint = {
            "epoch": t + 1,
            "model_state_dict": model.state_dict(),
            "optimizer_state_dict": optimizer.state_dict(),
            # "scheduler_state_dict": scheduler.state_dict(),
            # "best_dev_dice": best_dev_dice,
            # "best_dev_miou": best_dev_miou,
            "best_val_loss": best_val_loss,
            # "notes": f"Model saved based on best Micro Dice. Ignored index for metric: {EVAL_IGNORE_INDEX}"
        }
        torch.save(checkpoint, checkpoint_path)

    else:
    #     print(f"Validation Micro Dice score did not improve from {best_dev_dice:.6f}")
        print(f"Corresponding validation loss: {best_val_loss:.6f}")

    print(f"Debug Validation loss: {debug_val_loss:.6f}")
    print(f"Train loss: {train_loss:.6f}")

# Plot a training image reconstruction (inference only: eval mode, no autograd graph)
img, label = training_data[0]
img = img.to(device)
model.eval()
with torch.no_grad():
    res = model(img.unsqueeze(0))
plt.imshow(res[0].permute(1,2,0).cpu().detach().numpy())
plt.savefig(f"{IMAGE_SAVE_DIR}/test{t}.png", format="png")
plt.show()

print("\n--- Training Finished! ---")
# print(f"Best validation Micro Dice score achieved: {best_dev_dice:.6f}")
# print(f"Corresponding validation mIoU: {best_dev_miou:.6f}")
print(f"Best model saved to: {os.path.join(MODEL_SAVE_DIR, f'{MODEL_NAME}')}")


### Run Segmentation Transformer Autoencoder

In [None]:
"""
from tqdm import tqdm
import torch

if torch.backends.mps.is_available():
    device = torch.device("mps")
elif torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
    
#for param in model.encoder.parameters():
    #param.requires_grad = False
model = SegmentationTransformerAutoencoder(img_size=512, patch_size=16, in_channels=3, embed_dim=512, num_heads=8, ff_dim=1024, num_layers=6).to(device)
loss_fn = nn.CrossEntropyLoss()
learning_rate = 1e-3
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

best_eval_loss = np.inf 

for epoch in range(100):
    print(f"Epoch {epoch+1}\n-------------------------------")
    train_loss = trainSegmentation(train_dataloader, model, loss_fn, optimizer)
    eval_loss = evalSegmentation(val_dataloader, model, loss_fn)
    print(f"Eval Loss: {eval_loss} for epoch {epoch}")
    with open("test.txt", "a") as file:
        file.write(f"Eval Loss: {eval_loss} for epoch {epoch}\n")
        
    if eval_loss <= best_eval_loss:
        best_eval_loss = eval_loss
        checkpoint = {"model": model.state_dict(),
              "optimizer": optimizer.state_dict()}
        torch.save(checkpoint, f"autoencoder/checkpoint_{epoch}.pytorch")
        img, label = training_data[0]
        img = img.to(device)

        res = model(img.unsqueeze(0))
        plt.imshow(res[0].permute(1,2,0).cpu().detach().numpy())
        plt.savefig(f"test{epoch}.png", format="png")
        plt.show()

    

"""

In [None]:
"""
import matplotlib.pyplot as plt

if torch.backends.mps.is_available():
    device = torch.device("mps")
elif torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

checkpoint = torch.load("autoencoder/checkpoint_20.pytorch")
model.load_state_dict(checkpoint["model"])
model.to(device)

img, label = training_data[0]
img = img.to(device)

res = model(img.unsqueeze(0))
print(res.size())
plt.imshow(res[0].permute(1,2,0).cpu().detach().numpy())
plt.show
"""