In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
import numpy as np
import random
from torch.utils.data import Dataset, DataLoader, random_split

import torch.nn.functional as F

import tifffile as tiff
import os
import time


In [None]:
# Custom Dataset class
class ImageDataset(Dataset):
    """Dataset of 3-channel, channel-first TIFF images read from one directory.

    Each item is a float32 tensor of shape (3, H, W) with values scaled to
    [0, 1], optionally passed through ``transform``. Only the image is
    returned (no label), which is what the reconstruction loss needs.

    Args:
        image_dir: Directory scanned (non-recursively) for files ending in '.tiff'.
        transform: Optional callable applied to the tensor before it is returned.
    """

    def __init__(self, image_dir, transform=None):
        self.image_dir = image_dir
        self.transform = transform
        # Sorted so item indices (and therefore a seeded random_split) are
        # reproducible; os.listdir returns entries in arbitrary order.
        self.image_files = sorted(
            os.path.join(image_dir, name)
            for name in os.listdir(image_dir)
            if name.endswith('.tiff')
        )
        # No resizing here: spatial size is left to the transform pipeline.

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, idx):
        img_path = self.image_files[idx]
        image = tiff.imread(img_path)

        # Expect a channel-first array with exactly 3 channels: (3, H, W).
        if image.ndim != 3 or image.shape[0] != 3:
            raise ValueError(f"Image {img_path} does not have exactly 3 layers.")

        # Scale unsigned-integer images by their dtype's maximum (65535 for
        # 16-bit, 255 for 8-bit) so any bit depth lands in [0, 1]; other
        # dtypes keep the original 16-bit assumption.
        if np.issubdtype(image.dtype, np.unsignedinteger):
            scale = float(np.iinfo(image.dtype).max)
        else:
            scale = 65535.0
        image = torch.from_numpy(image.astype(np.float32) / scale)

        if self.transform is not None:
            image = self.transform(image)
        return image

# Augmentation pipeline: a random resized crop to 256x256.
# (RandomHorizontalFlip() and RandomRotation(10) are disabled in this notebook.)
# NOTE(review): the train/val split below wraps this same dataset object, so
# validation images are randomly cropped too and validation loss will vary
# from epoch to epoch; consider a deterministic transform for validation.
contrast_transforms = transforms.Compose(
    [transforms.RandomResizedCrop(size=256)]
)

# Build the dataset from the folder of TIFF images.
image_dir = r"../../Day10_drugscreened&singledose_untreated"
dataset = ImageDataset(image_dir=image_dir, transform=contrast_transforms)

batch_size = 16

# Function to split dataset with explicit percentage
def split_dataset(dataset, val_percentage, generator=None):
    """Randomly split ``dataset`` into (train, validation) subsets.

    Args:
        dataset: Any sized dataset accepted by ``torch.utils.data.random_split``.
        val_percentage: Fraction of items (in [0, 1]) used for validation. The
            validation size is ``int(len(dataset) * val_percentage)`` (rounded
            down); all remaining items go to training.
        generator: Optional ``torch.Generator`` for a reproducible split.
            ``None`` uses torch's global RNG, as before.

    Returns:
        ``[train_subset, val_subset]`` as returned by ``random_split``.

    Raises:
        ValueError: if ``val_percentage`` is outside [0, 1]. Without this
            check a negative split length would reach ``random_split``.
    """
    if not 0.0 <= val_percentage <= 1.0:
        raise ValueError(f"val_percentage must be in [0, 1], got {val_percentage!r}")
    val_size = int(len(dataset) * val_percentage)
    train_size = len(dataset) - val_size
    if generator is None:
        return random_split(dataset, [train_size, val_size])
    return random_split(dataset, [train_size, val_size], generator=generator)

# Seed every RNG this notebook uses so that the split below and all later
# random operations (crops, weight init, shuffling) repeat on a top-to-bottom run.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Split the dataset with 20% for validation
val_percentage = 0.2
train_dataset, val_dataset = split_dataset(dataset, val_percentage)

# Define DataLoaders.
# Pinned memory only speeds up host-to-GPU copies, so enable it only when CUDA
# is available. num_workers=0 loads in the main process; on a cluster GPU node
# consider num_workers=os.cpu_count().
use_pinned_memory = torch.cuda.is_available()
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, drop_last=False, pin_memory=use_pinned_memory, num_workers=0)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, drop_last=False, pin_memory=use_pinned_memory, num_workers=0)

In [None]:
# Sanity check: look only at the first training batch and report its shape.
first_batch = next(iter(train_loader), None)
if first_batch is not None:
    print("Batch 0:")
    print(f"  image: {first_batch.shape}")

In [None]:
class Autoencod_fituning(nn.Module):
    """Small convolutional autoencoder mapping a 3-channel image to a same-shape image.

    Encoder: three (3x3 conv -> ReLU -> 2x2/stride-2 max-pool) stages with
    3 -> 16 -> 32 -> 64 channels. Decoder: two (3x3 conv -> ReLU -> 2x nearest
    upsample) stages (64 -> 32 -> 16 channels), then a 3x3 conv to 3 channels,
    a final 2x upsample and a sigmoid.

    The first two max-pools use padding=1 and the last padding=0. For inputs
    whose height and width are multiples of 8 the output has the same spatial
    size as the input, e.g. for 256x256: 256 -> 129 -> 65 -> 32 in the encoder
    and 32 -> 64 -> 128 -> 256 in the decoder.
    NOTE(review): the padded pools shift the pooling grid relative to the
    decoder's uniform upsampling; confirm padding=1 is intended.
    """

    def __init__(self):
        super().__init__()

        # Encoder stages as (in_channels, out_channels, max-pool padding).
        # Layers are created in the same order as an explicit Sequential, so
        # module indices ("encoder.0", "encoder.3", "encoder.6") are unchanged.
        encoder_layers = []
        for c_in, c_out, pool_pad in ((3, 16, 1), (16, 32, 1), (32, 64, 0)):
            encoder_layers += [
                nn.Conv2d(c_in, c_out, kernel_size=3, padding='same'),
                nn.ReLU(),
                nn.MaxPool2d(2, stride=2, padding=pool_pad),
            ]
        self.encoder = nn.Sequential(*encoder_layers)

        # Decoder: two conv/ReLU/upsample stages, then a final conv to
        # 3 channels (no ReLU), one more upsample and a sigmoid.
        decoder_layers = []
        for c_in, c_out in ((64, 32), (32, 16)):
            decoder_layers += [
                nn.Conv2d(c_in, c_out, kernel_size=3, padding='same'),
                nn.ReLU(),
                nn.Upsample(scale_factor=2, mode='nearest'),
            ]
        decoder_layers += [
            nn.Conv2d(16, 3, kernel_size=3, padding='same'),
            nn.Upsample(scale_factor=2, mode='nearest'),
            nn.Sigmoid(),  # bounds reconstructions to (0, 1), matching the normalized inputs
        ]
        self.decoder = nn.Sequential(*decoder_layers)

    def forward(self, x):
        """Encode then decode ``x`` (expected shape (N, 3, H, W))."""
        # nn.Sequential applies its layers in order, same as an explicit loop.
        latent = self.encoder(x)
        return self.decoder(latent)


In [None]:
# Pixel-wise MSE reconstruction loss, the autoencoder, and Adam (lr = 1e-3).
criterion = nn.MSELoss()
modi = Autoencod_fituning()
optimizer = optim.Adam(params=modi.parameters(), lr=1e-3)

In [None]:
def _run_epoch(mo, loader, criterion, device, optimizer=None):
    """Make one pass over ``loader``; return the per-sample mean reconstruction loss.

    Performs backward + optimizer step per batch when ``optimizer`` is given,
    otherwise only evaluates. Returns NaN when the loader yields no samples.
    """
    total_loss = 0.0
    n_samples = 0
    for images in loader:
        images = images.to(device, non_blocking=True)
        if optimizer is not None:
            optimizer.zero_grad()
        outputs = mo(images)
        # Autoencoder objective: reconstruct the input itself.
        loss = criterion(outputs, images)
        if optimizer is not None:
            loss.backward()
            optimizer.step()
        # Weight each batch by its size so a short final batch (drop_last=False)
        # does not skew the epoch average. Assumes criterion averages over the
        # batch (e.g. MSELoss with its default reduction='mean').
        batch_n = images.size(0)
        total_loss += loss.item() * batch_n
        n_samples += batch_n
    return total_loss / n_samples if n_samples else float('nan')


def train_and_validate(mo, train_loader, val_loader, optimizer, criterion, num_epochs=1, device=None):
    """Train ``mo`` as an autoencoder and evaluate it after every epoch.

    Args:
        mo: Model to train; its output is compared against its own input.
        train_loader: Iterable of input batches (tensors) used for training.
        val_loader: Iterable of input batches used for validation (no grads).
        optimizer: Optimizer over ``mo``'s parameters.
        criterion: Loss callable ``criterion(outputs, targets)``.
        num_epochs: Number of passes over ``train_loader``.
        device: Device batches are moved to; defaults to the device of the
            model's first parameter (CPU if the model has no parameters).

    Returns:
        ``(train_losses, val_losses)``: per-epoch mean loss per sample.
        An epoch whose loader is empty records NaN.
    """
    if device is None:
        first_param = next(mo.parameters(), None)
        device = first_param.device if first_param is not None else torch.device('cpu')

    train_losses = []
    val_losses = []

    for epoch in range(num_epochs):
        start_time = time.perf_counter()

        # Training phase
        mo.train()
        train_loss = _run_epoch(mo, train_loader, criterion, device, optimizer)
        train_losses.append(train_loss)

        # Validation phase
        mo.eval()
        with torch.no_grad():
            val_loss = _run_epoch(mo, val_loader, criterion, device)
        val_losses.append(val_loss)

        epoch_time = time.perf_counter() - start_time
        print(f'Epoch [{epoch+1}/{num_epochs}], Train Loss: {train_loss:.4f}, Validation Loss: {val_loss:.4f}, Time: {epoch_time:.2f} seconds')

    return train_losses, val_losses


#train_losses, val_losses = train_and_validate(modi, train_loader, val_loader, optimizer, criterion, num_epochs=1)


In [None]:
# Full training run: 10 epochs over the train split, validating after each.
train_losses, val_losses = train_and_validate(
    mo=modi,
    train_loader=train_loader,
    val_loader=val_loader,
    optimizer=optimizer,
    criterion=criterion,
    num_epochs=10,
)