In [None]:
import os
import random
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from PIL import Image

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torchvision.utils import make_grid

!pip install livelossplot
from livelossplot import PlotLosses

# If on drive retrieve the images
try:
    from google.colab import drive
    drive.mount('/content/drive')
    data_path = '/content/drive/MyDrive/Fiber Finder/generative/images.zip'
    local_path = '/content/'
    !cp -r "$data_path" "$local_path"
    !unzip "/content/images.zip" -d "/content" > /dev/null 2>&1
    print('Data recieved and unzipped!')

except Exception as e:
    print('No Drive or Error:', e)

data_path = './images'
data_path = '../images'

def set_seed(seed):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.enabled   = False

    return True

device = 'cpu'
if torch.cuda.device_count() > 0 and torch.cuda.is_available():
    print("Cuda installed! Running on GPU!")
    device = 'cuda'
else:
    print("No GPU available!")

In [None]:
class FiberDataset(Dataset):
    """Creates a dataset from an image file."""
    def __init__(self, image_dir, transform=None):
        self.image_dir = image_dir
        self.transform = transform
        self.files = os.listdir(image_dir)
        self.image_paths = []

        self.labels = {}
        self.false = []

        self._parse_files()

    def _parse_files(self):
        for file_name in self.files:
            file_name = os.path.join(self.image_dir, file_name)
            if file_name.endswith(".png"):
                self.image_paths.append(file_name)
            elif file_name.endswith(".csv"):
                labels = pd.read_csv(file_name, header=None)[0]
                labels.index += 1
                labels.index += 2000
                labels = dict(labels)
                self.labels = labels
            else:
                self.false.append(file_name)
        pass

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        file_name = self.image_paths[idx]
        image_no = int(file_name.split("/")[-1].replace(".png", ""))
        image = Image.open(file_name).convert('L')

        label = self.labels[image_no]

        if self.transform:
            image = self.transform(image)

        return image, label


transform = transforms.Compose([transforms.ToTensor()])

data_path_bio = data_path + "/bio/"
data_path_bio = data_path + "/diffusion/diffusion_voxels/"

dataset = FiberDataset(data_path_bio, transform=transform)
exp_loader = DataLoader(dataset, batch_size=1, shuffle=False)


In [None]:
def show_images_grid(images, title=None, nrow=8):
    """Plots the images in a grid"""
    img_grid = make_grid(images, nrow=nrow).numpy()
    plt.imshow(np.transpose(img_grid, (1, 2, 0)))
    plt.title(title)
    plt.axis('off')
    plt.show()

# DataLoader for the dataset
exp_loader = DataLoader(dataset, shuffle=True)

# Thresholds
low_threshold = 0.4
high_threshold = 0.65

# Lists to store samples
very_low_mean_samples = []
medium_mean_samples = []
high_mean_samples = []

for images, labels in exp_loader:
    for img in images:
        mean_intensity = img.mean()
        if mean_intensity <= low_threshold and len(very_low_mean_samples) < 32:
            very_low_mean_samples.append(img)
        elif low_threshold < mean_intensity <= high_threshold and len(medium_mean_samples) < 32:
            medium_mean_samples.append(img)
        elif mean_intensity > high_threshold and len(high_mean_samples) < 32:
            high_mean_samples.append(img)

    if len(very_low_mean_samples) >= 32 and len(medium_mean_samples) >= 32 and len(high_mean_samples) >= 32:
        break

show_images_grid(very_low_mean_samples, "Very Low Mean Pixel Intensity Images")
show_images_grid(medium_mean_samples, "Medium Mean Pixel Intensity Images")
show_images_grid(high_mean_samples, "High Mean Pixel Intensity Images")

In [None]:
class FilteredDataset(Dataset):
    """Creates a filtered set from the standard dataset using tresholds."""
    def __init__(self, original_dataset, low_threshold=0., high_threshold=1, transform=None):
        self.transform = transform
        self.data = [(img, label) for img, label in original_dataset if low_threshold < img.mean() <= high_threshold]

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        image = self.data[idx][0]
        label = self.data[idx][1]

        if self.transform:
            image = self.transform(image)

        return image, label


transform = transforms.Compose([transforms.ToTensor()])

dataset = FiberDataset(data_path_bio, transform=transform)
filtered_dataset = FilteredDataset(dataset, low_threshold=.1, high_threshold=.65)
filtered_loader = DataLoader(filtered_dataset, batch_size=1, shuffle=False)
filtered_samples = []

for i, data in enumerate(filtered_dataset):
    image = data[0]
    filtered_samples.append(image)
    if i >= 31:
        break

show_images_grid(filtered_samples, "Sample Images from Filtered Dataset")


In [None]:
transform = transforms.Compose([transforms.ToTensor(),
                                transforms.RandomHorizontalFlip(),
                                transforms.RandomVerticalFlip()])

dataset = FiberDataset(data_path_bio, transform=transform)
filtered_dataset = FilteredDataset(dataset, low_threshold=.1, high_threshold=.65)
loader = DataLoader(filtered_dataset, batch_size=64, shuffle=False)

In [49]:
class UNet(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(UNet, self).__init__()
        # Define a simple U-Net architecture
        # You should replace this with a more complex and effective U-Net design
        self.encoder = nn.Sequential(
            nn.Conv2d(in_channels, 64, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
        )

        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(64, 64, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(64, out_channels, kernel_size=4, stride=2, padding=1),
            nn.Tanh(),
        )

    def forward(self, x):
        x = self.encoder(x)
        x = self.decoder(x)
        return x

class DiffusionModel(nn.Module):
    def __init__(self, in_channels, num_diffusion_steps):
        super(DiffusionModel, self).__init__()
        self.unet = UNet(1, in_channels)
        self.num_diffusion_steps = num_diffusion_steps

    def forward(self, x, t):
        # x: Input data (e.g., an image)
        # t: Time step in the diffusion process
        noise = torch.randn_like(x)
        noisy_data = self.q_sample(x, t, noise)
        predicted_noise = self.unet(noisy_data)
        return predicted_noise

    def q_sample(model, x, t, noise):
        # Apply noise
        noisy_x = x + noise
        # Pass through the model
        predicted_noise = model(noisy_x, t)
        return noisy_x, predicted_noise

    def p_sample(self, x_t, t, predicted_noise):
        # Denoising step using predicted noise
        # This should be defined according to your noise schedule and model design.
        return x_t - predicted_noise



In [50]:
def train_diffusion_model(model, data_loader, epochs, model_path, save=False):
    # Optimizer
    lr = 1e-5
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    # Loss function (modify according to your diffusion model's requirement)
    loss_function = nn.MSELoss()

    groups = {'Loss': ['total_loss']}
    liveloss = PlotLosses(groups=groups)

    for epoch in range(epochs):
        logs = {}
        total_loss = 0

        for real_data, _ in data_loader:
            batch_size = real_data.size(0)
            real_data = real_data.to(device)

            # Assuming your diffusion model has a method to generate noise and denoise
            t = torch.randint(0, model.num_diffusion_steps, (batch_size, 1)).to(device)
            noise = torch.randn_like(real_data)
            # Call q_sample with both real_data, noise, and time step t
            noisy_data = model.q_sample(real_data, t, noise)
            predicted_noise = model(noisy_data, t)

            optimizer.zero_grad()

            # Calculate loss - example given for an MSE loss
            loss = loss_function(predicted_noise, noise)
            loss.backward()
            optimizer.step()

            total_loss += loss.item()

        # Average losses for current epoch
        logs['total_loss'] = total_loss / len(data_loader)

        # Update livelossplot
        liveloss.update(logs)
        liveloss.send()

        # Save model parameters every 'save_interval' epochs
        if (epoch + 1) % 10 == 0 and save:
            model_path = os.path.join(model_path, f'diffusion_model_epoch_{epoch + 1}.pth')
            torch.save(model.state_dict(), model_path)

    return model

in_channels = 50 * 50
num_diffusion_steps = 1000
model = DiffusionModel(in_channels, num_diffusion_steps).to(device)

# Training the Diffusion Model
epochs = 1500
model_path = './DiffusionModel'
trained_model = train_diffusion_model(model, loader, epochs, model_path, save=True)

RecursionError: maximum recursion depth exceeded

In [None]:
def plot_samples(model, num_samples=32, nrow=8, save=False, model_type='VAE'):
    save_dir=f'./{model_type}_fiber'
    # Generate random latent vectors
    z = torch.randn(num_samples, model.fc[0].in_features).to(device)

    # Decode the latent vectors
    model.eval()
    with torch.no_grad():
        imgs = model(z).view(-1, 1, 50, 50).to('cpu')
    # Create a grid of images and display
    img_grid = make_grid(imgs, nrow=nrow)
    plt.figure(figsize=(10, 10))
    plt.imshow(np.transpose(img_grid.numpy(), (1, 2, 0)), cmap='gray')
    plt.axis('off')
    plt.show()

    # Save the images if requested
    if save:
        if not os.path.exists(save_dir):
            os.makedirs(save_dir)
        for i, img in enumerate(imgs):
            fname = f'{str(i+1).zfill(3)}.jpeg'
            img_np = img.squeeze().numpy()
            if img_np.max() <= 1:
                img_np = (img_np * 255).astype(np.uint8)
            img_pil = Image.fromarray(img_np)
            img_pil.save(os.path.join(save_dir, fname))

trained_generator = gan[0]

plot_samples(trained_generator, num_samples=32, nrow=8, save=False, model_type='GAN')