# DeepTransGAN: Training and Prediction Notebook
This notebook walks through training and using a GAN-based neural network that converts low-light (INR) images to normal-light (RGB) images. The project draws on AI image-enhancement techniques and aims to improve visibility in low-light conditions for applications such as autonomous driving and surveillance.

## 1. Dataset Preparation

The dataset consists of paired INR (low-light) images and their corresponding RGB (normal-light) images. The `LowLightDataset` class in `datasets/data_set.py` loads and preprocesses these pairs.

### 1.1. Dataset Format

The dataset should be organized as follows:

```
your_data_directory/
    our485/  # Training data
        high/  # Contains RGB images
            image1.png
            image2.jpg
            ...
        low/   # Contains INR images
            image1.png
            image2.jpg
            ...
    eval15/  # Testing data
        high/
            image1.png
            image2.jpg
            ...
        low/
            image1.png
            image2.jpg
            ...
```
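
Before training, it can help to confirm that every low-light file has a normal-light counterpart. The sketch below is a minimal check that assumes paired images share the same file name in `low/` and `high/`, as the listing above suggests. `find_unpaired`, `list_images`, and the extension set are hypothetical helpers, not part of `datasets/data_set.py`.

```python
from pathlib import Path

# Assumed set of image extensions; adjust to match your data.
IMAGE_SUFFIXES = {".png", ".jpg", ".jpeg", ".bmp"}


def list_images(folder):
    """Return the names of image files located directly inside `folder`."""
    return {
        p.name
        for p in Path(folder).iterdir()
        if p.is_file() and p.suffix.lower() in IMAGE_SUFFIXES
    }


def find_unpaired(data_dir, split="our485"):
    """Return (low_only, high_only): names found in one folder but not the other."""
    low = list_images(Path(data_dir) / split / "low")
    high = list_images(Path(data_dir) / split / "high")
    return sorted(low - high), sorted(high - low)
```

Calling `find_unpaired("your_data_directory", split="eval15")` runs the same check on the test split; two empty lists mean the split is fully paired.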

### 1.2. Loading and Visualizing Data

Here's an example of how to use `datasets/data_set.py` to load and visualize the data:

In [None]:
import os
import cv2
import matplotlib.pyplot as plt
from torchvision import transforms
from datasets.data_set import LowLightDataset

# 1. Set the dataset directory
data_dir = "../datasets/LOLdataset"  # Replace with your dataset path

# 2. Define image transformations
transform = transforms.Compose([
    transforms.ToPILImage(),
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

# 3. Create dataset instance
train_dataset = LowLightDataset(image_dir=data_dir, transform=transform, phase="train")

# 4. Access a sample
low_img, high_img = train_dataset[0]

# 5. Convert tensors to numpy arrays and rescale
low_img_np = low_img.permute(1, 2, 0).numpy() * 0.5 + 0.5
high_img_np = high_img.permute(1, 2, 0).numpy() * 0.5 + 0.5

# 6. Display the images
fig, axes = plt.subplots(1, 2, figsize=(10, 5))
axes[0].imshow(low_img_np)
axes[0].set_title("Low-light (INR) Image")
axes[1].imshow(high_img_np)
axes[1].set_title("Normal-light (RGB) Image")
plt.show()
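
`transforms.Normalize` with a mean and standard deviation of 0.5 maps each channel from [0, 1] to [-1, 1] via `(x - 0.5) / 0.5`, which is why step 5 multiplies by 0.5 and adds 0.5 before plotting. A small NumPy sketch of that round trip follows; the function names are illustrative only.

```python
import numpy as np


def normalize(x, mean=0.5, std=0.5):
    """Same arithmetic as Normalize(mean, std): [0, 1] -> [-1, 1] for 0.5/0.5."""
    return (x - mean) / std


def denormalize(y, mean=0.5, std=0.5):
    """Inverse mapping used before display: [-1, 1] -> [0, 1]."""
    return y * std + mean


pixels = np.array([0.0, 0.25, 0.5, 1.0], dtype=np.float32)
normalized = normalize(pixels)       # values now span [-1, 1]
restored = denormalize(normalized)   # back to the original [0, 1] values
```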

## 2. Model Definition

The core of the INR2RGB conversion is a Generator model, which attempts to transform the low-light INR image into a corresponding RGB image. A Discriminator (or Critic in WGAN) model is used to evaluate the quality of the generated images and provide feedback to the Generator during training.

The models are defined in `models/base_mode.py`. The Generator uses RepViT blocks, SPPELAN, and other convolutional layers to extract features and generate the RGB image. The Discriminator (or Critic) uses Disconv layers to classify images as real or fake.

### 2.1. Generator

The Generator architecture consists of several convolutional blocks, upsampling layers, and concatenation operations. It takes an INR image as input and outputs an RGB image.
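
To make the description concrete, here is a deliberately tiny encoder-decoder in PyTorch that combines the same three ingredients: convolutional blocks, an upsampling layer, and a concatenation (skip connection). `TinyGenerator` is a hypothetical illustration. It does not reproduce the RepViT or SPPELAN blocks of the actual `Generator` in `models/base_mode.py`, and the final `Tanh` is an assumption chosen to match inputs normalized to [-1, 1].

```python
import torch
import torch.nn as nn


class TinyGenerator(nn.Module):
    """Illustrative only: conv blocks + upsampling + one skip concatenation."""

    def __init__(self, channels=3, width=16):
        super().__init__()
        # Full-resolution feature extractor
        self.enc = nn.Sequential(nn.Conv2d(channels, width, 3, padding=1), nn.ReLU())
        # Strided convolution halves the spatial size and doubles the channels
        self.down = nn.Sequential(
            nn.Conv2d(width, width * 2, 3, stride=2, padding=1), nn.ReLU()
        )
        # Upsampling restores the original spatial size
        self.up = nn.Upsample(scale_factor=2, mode="nearest")
        # After concatenation the tensor has width * 2 + width channels
        self.dec = nn.Sequential(nn.Conv2d(width * 3, width, 3, padding=1), nn.ReLU())
        self.out = nn.Sequential(nn.Conv2d(width, channels, 1), nn.Tanh())

    def forward(self, x):
        skip = self.enc(x)
        deep = self.up(self.down(skip))
        merged = torch.cat([deep, skip], dim=1)  # skip connection
        return self.out(self.dec(merged))
```

With even input sizes such as 256 x 256, the output has the same shape as the input, so a 3-channel image goes in and a 3-channel image comes out.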

### 2.2. Discriminator (or Critic)

The Discriminator is a binary classifier that distinguishes real RGB images from generated ones. In the WGAN variant, the Critic instead outputs an unbounded score rather than a real/fake probability. In both cases the output is the training signal that guides the Generator toward more realistic images.
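
The two options differ mainly in their loss. A discriminator trained with a logistic loss pushes its logits toward "real" or "fake" labels, whereas a WGAN critic is trained to widen the gap between its mean score on real images and its mean score on generated ones. The NumPy sketch below illustrates both on made-up scores. It uses plain binary cross-entropy with logits, not the loss classes in `utils/loss.py`, and it omits the Lipschitz constraint (weight clipping or a gradient penalty) that WGAN training also requires.

```python
import numpy as np


def bce_with_logits(logits, targets):
    """Numerically stable binary cross-entropy on raw logits, averaged over the batch."""
    logits = np.asarray(logits, dtype=np.float64)
    targets = np.asarray(targets, dtype=np.float64)
    per_item = np.maximum(logits, 0) - logits * targets + np.log1p(np.exp(-np.abs(logits)))
    return float(np.mean(per_item))


def discriminator_loss(real_logits, fake_logits):
    """Real images are labelled 1, generated images 0; the two terms are averaged."""
    real_term = bce_with_logits(real_logits, np.ones_like(real_logits))
    fake_term = bce_with_logits(fake_logits, np.zeros_like(fake_logits))
    return (real_term + fake_term) / 2


def critic_loss(real_scores, fake_scores):
    """Wasserstein critic loss: minimizing it raises real scores and lowers fake ones."""
    return float(np.mean(fake_scores) - np.mean(real_scores))


real = np.array([2.0, 3.0])    # made-up outputs for real images
fake = np.array([-2.0, -1.0])  # made-up outputs for generated images
```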

In [None]:
import torch
from models.base_mode import Generator, Discriminator, Critic

# Example instantiation
generator = Generator()
discriminator = Discriminator()
# For WGAN training, use Critic instead of Discriminator
# critic = Critic()

# Print model architecture
print("Generator architecture:")
print(generator)
print("\nDiscriminator architecture:")
print(discriminator)
# print("\nCritic architecture:")
# print(critic)

## 3. Training

The `train.py` script is used to train the INR2RGB model. It defines the training loop, loss functions, optimizers, and other training parameters.

### 3.1. Training Process

The training process involves the following steps:

1.  **Load the dataset**: The `LowLightDataset` class is used to load the INR and RGB image pairs.
2.  **Define the models**: The Generator and Discriminator (or Critic) models are instantiated.
3.  **Define the loss functions**: `BCEBlurWithLogitsLoss` (or `MSELoss`) is used for the Generator, and `BCEBlurWithLogitsLoss` is used for the Discriminator.
4.  **Define the optimizers**: Adam is used to optimize the Generator and Discriminator (or Critic) parameters.
5.  **Iterate over the dataset**: For each batch of images, the following steps are performed:
    *   Generate fake RGB images from the INR images using the Generator.
    *   Train the Discriminator (or Critic) to distinguish between real and fake images.
    *   Train the Generator to produce more realistic images that can fool the Discriminator (or Critic).
6.  **Evaluate the model**: After each epoch, the model is evaluated on a validation set using metrics such as PSNR and SSIM.

### 3.2. Training Script Usage

To train the model, run the following command:

```bash
!python train.py --data your_data_directory --epochs 100 --loss mse --batch_size 8
```

Replace `your_data_directory` with the path to your dataset directory. You can also adjust the other training parameters as needed.

For WGAN training, use the `--wgan` flag:

```bash
!python train.py --data your_data_directory --epochs 100 --loss mse --batch_size 8 --wgan True
```
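
The flags shown in these commands could be parsed along the following lines. This is a hypothetical sketch, not the actual parser in `train.py`; the defaults and the `str2bool` helper are assumptions. It also shows why a flag passed as `--wgan True` needs an explicit string-to-boolean converter: with `type=bool`, argparse would treat any non-empty string, including `False`, as true.

```python
import argparse


def str2bool(value):
    """Convert 'True'/'False'-style strings to a real boolean."""
    text = str(value).strip().lower()
    if text in {"true", "1", "yes"}:
        return True
    if text in {"false", "0", "no"}:
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean, got {value!r}")


def build_parser():
    """Hypothetical parser mirroring the flags used in the commands above."""
    parser = argparse.ArgumentParser(description="Training options (illustrative)")
    parser.add_argument("--data", required=True, help="dataset root directory")
    parser.add_argument("--epochs", type=int, default=100)
    parser.add_argument("--loss", default="mse")
    parser.add_argument("--batch_size", type=int, default=8)
    parser.add_argument("--wgan", type=str2bool, default=False)
    return parser


args = build_parser().parse_args(
    ["--data", "your_data_directory", "--epochs", "100",
     "--loss", "mse", "--batch_size", "8", "--wgan", "True"]
)
```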

### 3.3. TensorBoard Visualization

You can use TensorBoard to visualize the training process. To start TensorBoard, run the following command:

```bash
!tensorboard --logdir runs/
```

Then, open your browser and navigate to `http://localhost:6006` to view the TensorBoard dashboard.

In [None]:
import os
import torch
import torch.optim as optim
from torchvision import transforms
from datasets.data_set import LowLightDataset
from models.base_mode import Generator, Discriminator
from utils.loss import BCEBlurWithLogitsLoss
from torch.utils.data import DataLoader

# 1. Set the device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# 2. Define the data directory and other parameters
data_dir = "../datasets/LOLdataset"  # Replace with your dataset path
batch_size = 8
img_size = (256, 256)
num_epochs = 2
learning_rate = 0.0002

# 3. Define image transformations
transform = transforms.Compose([
    transforms.ToPILImage(),
    transforms.Resize(img_size),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

# 4. Create the dataset and data loader
train_dataset = LowLightDataset(image_dir=data_dir, transform=transform, phase="train")
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

# 5. Instantiate the models
generator = Generator().to(device)
discriminator = Discriminator().to(device)

# 6. Define the loss functions
g_loss_fn = BCEBlurWithLogitsLoss().to(device)
d_loss_fn = BCEBlurWithLogitsLoss().to(device)

# 7. Define the optimizers
g_optimizer = optim.Adam(generator.parameters(), lr=learning_rate)
d_optimizer = optim.Adam(discriminator.parameters(), lr=learning_rate)

# 8. Training loop
for epoch in range(num_epochs):
    for i, (low_images, high_images) in enumerate(train_loader):
        low_images = low_images.to(device)
        high_images = high_images.to(device)

        # Train the discriminator
        d_optimizer.zero_grad()
        fake_images = generator(low_images)
        # Create the labels directly on the target device
        real_labels = torch.ones(low_images.size(0), device=device)
        fake_labels = torch.zeros(low_images.size(0), device=device)

        d_real_loss = d_loss_fn(discriminator(high_images).squeeze(), real_labels)
        d_fake_loss = d_loss_fn(discriminator(fake_images.detach()).squeeze(), fake_labels)
        d_loss = (d_real_loss + d_fake_loss) / 2
        d_loss.backward()
        d_optimizer.step()

        # Train the generator
        g_optimizer.zero_grad()
        g_loss = g_loss_fn(discriminator(fake_images).squeeze(), real_labels)
        g_loss.backward()
        g_optimizer.step()

        if i % 10 == 0:
            print(f"Epoch [{epoch+1}/{num_epochs}], Step [{i+1}/{len(train_loader)}], d_loss: {d_loss.item():.4f}, g_loss: {g_loss.item():.4f}")

# 9. Save the trained models
os.makedirs("runs/generator", exist_ok=True)
os.makedirs("runs/discriminator", exist_ok=True)
torch.save(generator.state_dict(), "runs/generator/generator.pth")
torch.save(discriminator.state_dict(), "runs/discriminator/discriminator.pth")

## 4. Prediction

The `predict.py` script is used to generate RGB images from INR images using the trained Generator model.

### 4.1. Prediction Script Usage

To perform prediction on a single image or a directory of images, run the following command:

```bash
!python predict.py --data path_to_image_or_directory --model runs/generator/generator.pth
```

Replace `path_to_image_or_directory` with the path to the INR image or directory containing INR images. Replace `runs/generator/generator.pth` with the path to the trained Generator model.

To use the model with a live camera feed, run the following command:

```bash
!python predict.py --data 0 --model runs/generator/generator.pth
```

This will open a window displaying the live camera feed and the generated RGB images.
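
A single `--data` argument that accepts a file, a directory, or a camera index implies some dispatch on its value. One way this could look is sketched below. `resolve_source` and the extension set are hypothetical and are not taken from `predict.py`.

```python
from pathlib import Path

# Assumed set of image extensions; adjust to match your data.
IMAGE_SUFFIXES = {".png", ".jpg", ".jpeg", ".bmp"}


def resolve_source(data):
    """Classify a --data value as ('camera', index) or ('images', [paths])."""
    path = Path(data)
    # A bare integer that is not an existing path is treated as a camera index
    if data.isdigit() and not path.exists():
        return "camera", int(data)
    if path.is_dir():
        files = sorted(
            p for p in path.iterdir()
            if p.is_file() and p.suffix.lower() in IMAGE_SUFFIXES
        )
        return "images", files
    if path.is_file():
        return "images", [path]
    raise FileNotFoundError(f"Not an image, directory, or camera index: {data}")
```

In such a design, the camera branch would hand the index to `cv2.VideoCapture`, while the image branch would loop over the returned paths.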

In [None]:
import os
import cv2
import torch
from torchvision import transforms
from models.base_mode import Generator
import numpy as np
import matplotlib.pyplot as plt

# 1. Set the device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# 2. Define the model path and image path
model_path = "runs/generator/generator.pth"  # Replace with your model path
image_path = "../datasets/LOLdataset/eval15/low/1.png"  # Replace with your image path

# 3. Define image transformations
img_size = (256, 256)
transform = transforms.Compose([
    transforms.ToPILImage(),
    transforms.Resize(img_size),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

# 4. Load the model
generator = Generator().to(device)
# map_location lets a checkpoint saved on GPU load on a CPU-only machine
generator.load_state_dict(torch.load(model_path, map_location=device))
generator.eval()

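# 5. Load the image (cv2.imread returns None instead of raising on a bad path)
img = cv2.imread(image_path)
if img is None:
    raise FileNotFoundError(f"Could not read image: {image_path}")
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
img_tensor = transform(img).unsqueeze(0).to(device)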

# 6. Perform prediction
with torch.no_grad():
    generated_img = generator(img_tensor)

# 7. Convert the generated image to a numpy array, rescale to [0, 1], and clip for display
generated_img_np = generated_img.squeeze(0).permute(1, 2, 0).cpu().numpy() * 0.5 + 0.5
generated_img_np = np.clip(generated_img_np, 0.0, 1.0)

# 8. Display the original and generated images
fig, axes = plt.subplots(1, 2, figsize=(10, 5))
axes[0].imshow(img)
axes[0].set_title("Original INR Image")
axes[1].imshow(generated_img_np)
axes[1].set_title("Generated RGB Image")
plt.show()

## 5. Evaluation

The performance of the trained model can be evaluated using metrics such as Peak Signal-to-Noise Ratio (PSNR) and Structural Similarity Index Measure (SSIM).

### 5.1. PSNR

PSNR measures the pixel-level fidelity of the generated image against the ground-truth RGB image. It is derived from the mean squared error and reported in decibels (dB); higher values indicate a closer match.
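
For images whose maximum possible pixel value is MAX, PSNR is defined as 10 · log10(MAX² / MSE). The NumPy sketch below is a minimal illustration that assumes both images are floats scaled to [0, 1], so MAX = 1. It is separate from the `torcheval` function used in the evaluation cell.

```python
import numpy as np


def psnr(pred, target, data_range=1.0):
    """Peak signal-to-noise ratio in dB for images with the given value range."""
    pred = np.asarray(pred, dtype=np.float64)
    target = np.asarray(target, dtype=np.float64)
    mse = np.mean((pred - target) ** 2)
    if mse == 0:
        return float("inf")  # identical images
    return float(10.0 * np.log10(data_range ** 2 / mse))


target = np.full((4, 4), 0.5)
pred = target + 0.1  # uniform error of 0.1, so MSE = 0.01
```

Here `psnr(pred, target)` works out to 10 · log10(1 / 0.01) = 20 dB.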

### 5.2. SSIM

SSIM compares luminance, contrast, and structure between the generated image and the ground-truth RGB image. SSIM values range from -1 to 1, with higher values indicating greater similarity; identical images score 1.
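
The SSIM formula combines means, variances, and the covariance of the two images. The sketch below evaluates it once over the whole image, which is a simplification: standard SSIM computes the same expression over local windows and averages the results, and the project's own `ssim` helper may differ. The constants follow the usual choice of (0.01 · L)² and (0.03 · L)², where L is the data range.

```python
import numpy as np


def global_ssim(x, y, data_range=1.0):
    """Single-window SSIM computed from whole-image statistics (simplified)."""
    x = np.asarray(x, dtype=np.float64)
    y = np.asarray(y, dtype=np.float64)
    c1 = (0.01 * data_range) ** 2  # stabilizes the luminance term
    c2 = (0.03 * data_range) ** 2  # stabilizes the contrast/structure term
    mu_x, mu_y = x.mean(), y.mean()
    var_x, var_y = x.var(), y.var()
    cov_xy = ((x - mu_x) * (y - mu_y)).mean()
    numerator = (2 * mu_x * mu_y + c1) * (2 * cov_xy + c2)
    denominator = (mu_x ** 2 + mu_y ** 2 + c1) * (var_x + var_y + c2)
    return float(numerator / denominator)


image = np.linspace(0.0, 1.0, 64).reshape(8, 8)
inverted = 1.0 - image  # same contrast, opposite structure
```

Comparing `image` with itself gives 1, while comparing it with `inverted` gives a negative value because the covariance is negative.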

In [None]:
import cv2
import torch
from torchvision import transforms
from datasets.data_set import LowLightDataset
from models.base_mode import Generator
from torcheval.metrics.functional import peak_signal_noise_ratio
from utils.misic import ssim
from torch.utils.data import DataLoader

# 1. Set the device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# 2. Define the model path and data directory
model_path = "runs/generator/generator.pth"  # Replace with your model path
data_dir = "../datasets/LOLdataset"  # Replace with your dataset path

# 3. Define image transformations
img_size = (256, 256)
transform = transforms.Compose([
    transforms.ToPILImage(),
    transforms.Resize(img_size),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

# 4. Create the test dataset and data loader
test_dataset = LowLightDataset(image_dir=data_dir, transform=transform, phase="test")
test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False)

# 5. Load the model
generator = Generator().to(device)
# map_location lets a checkpoint saved on GPU load on a CPU-only machine
generator.load_state_dict(torch.load(model_path, map_location=device))
generator.eval()

# 6. Evaluate the model
psnr_values = []
ssim_values = []

with torch.no_grad():
    for low_images, high_images in test_loader:
        low_images = low_images.to(device)
        high_images = high_images.to(device)

        fake_images = generator(low_images)

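        # Rescale from [-1, 1] to [0, 1] so PSNR is computed against a fixed data range
        fake_01 = (fake_images * 0.5 + 0.5).clamp(0, 1)
        high_01 = high_images * 0.5 + 0.5
        psnr = peak_signal_noise_ratio(fake_01, high_01, data_range=1.0).item()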
        ssim_val = ssim(fake_images, high_images).item()

        psnr_values.append(psnr)
        ssim_values.append(ssim_val)

# 7. Print the results
print(f"PSNR: {sum(psnr_values) / len(psnr_values):.4f}")
print(f"SSIM: {sum(ssim_values) / len(ssim_values):.4f}")

## 6. Conclusion

This notebook provides a step-by-step guide to training and using a neural network for converting INR images to RGB images. By following the instructions in this notebook, you can train your own INR2RGB model and use it to enhance night vision capabilities for various applications.

### 6.1. Future Directions

*   Experiment with different model architectures and loss functions.
*   Train the model on larger and more diverse datasets.
*   Explore the use of transfer learning to improve the performance of the model.
*   Implement real-time INR2RGB conversion for live video feeds.