In [None]:
#| eval: false

import warnings
# NOTE(review): this silences *every* warning for the rest of the session,
# including any deprecation or diagnostic warnings raised by later cells
# (model loading, distillation, quantization). Consider narrowing it by
# category/module so real problems stay visible.
warnings.filterwarnings('ignore')

from fastai.vision.all import *
from fasterbench.benchmark import evaluate_cpu_speed, get_model_size, get_num_parameters

In [None]:
#| eval: false
import torch.nn as nn
import torch
class dfus_block(nn.Module):
    """Dense fusion block of HSCNN+.

    A 1x1 bottleneck squeezes the input to 128 channels. That feeds two
    parallel paths, each a 3x3 conv (32 ch) followed by a 1x1 conv (16 ch).
    The four intermediate maps (32 + 16 + 32 + 16 = 96 channels) are fused by
    a 1x1 conv down to 32 channels and appended to the block input. Every
    block therefore grows the channel count by 32.

    Args:
        dim: number of input channels.
    """

    def __init__(self, dim):
        super().__init__()
        # Attribute names are the state_dict keys used when loading pretrained
        # checkpoints, so they (including the "fution" spelling) must not change.
        self.conv1 = nn.Conv2d(dim, 128, kernel_size=1, stride=1, padding=0, bias=False)

        self.conv_up1 = nn.Conv2d(128, 32, kernel_size=3, stride=1, padding=1, bias=False)
        self.conv_up2 = nn.Conv2d(32, 16, kernel_size=1, stride=1, padding=0, bias=False)

        self.conv_down1 = nn.Conv2d(128, 32, kernel_size=3, stride=1, padding=1, bias=False)
        self.conv_down2 = nn.Conv2d(32, 16, kernel_size=1, stride=1, padding=0, bias=False)

        self.conv_fution = nn.Conv2d(96, 32, kernel_size=1, stride=1, padding=0, bias=False)

        # In-place ReLU is only ever applied to fresh conv outputs, never to `x`.
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        """Map x of shape [b, dim, h, w] to [b, dim + 32, h, w]."""
        act = self.relu
        squeezed = act(self.conv1(x))

        up_a = act(self.conv_up1(squeezed))
        up_b = act(self.conv_up2(up_a))
        down_a = act(self.conv_down1(squeezed))
        down_b = act(self.conv_down2(down_a))

        fused = act(self.conv_fution(torch.cat((up_a, up_b, down_a, down_b), dim=1)))
        # Dense connection: keep the input and append the 32 new channels.
        return torch.cat((x, fused), dim=1)

class ddfn(nn.Module):
    """HSCNN+ trunk: a two-path stem followed by a chain of dense fusion blocks.

    The stem runs two parallel paths on the input, each a 3x3 conv (32 ch)
    followed by a 1x1 conv (32 ch). All four intermediate maps are
    concatenated into 4 * 32 = 128 channels. Block i then receives
    128 + 32*i channels and appends 32 more, so the output has
    128 + 32 * num_blocks channels.

    Args:
        dim: number of input channels.
        num_blocks: number of dfus_block stages.
    """

    def __init__(self, dim, num_blocks=78):
        super().__init__()

        self.conv_up1 = nn.Conv2d(dim, 32, kernel_size=3, stride=1, padding=1, bias=False)
        self.conv_up2 = nn.Conv2d(32, 32, kernel_size=1, stride=1, padding=0, bias=False)

        self.conv_down1 = nn.Conv2d(dim, 32, kernel_size=3, stride=1, padding=1, bias=False)
        self.conv_down2 = nn.Conv2d(32, 32, kernel_size=1, stride=1, padding=0, bias=False)

        # Block i sees the 128 stem channels plus 32 from each earlier block.
        self.dfus_blocks = nn.Sequential(
            *(dfus_block(dim=128 + 32 * i) for i in range(num_blocks))
        )

        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        """Map x of shape [b, dim, h, w] to [b, 128 + 32 * num_blocks, h, w]."""
        act = self.relu
        up_a = act(self.conv_up1(x))
        up_b = act(self.conv_up2(up_a))
        down_a = act(self.conv_down1(x))
        down_b = act(self.conv_down2(down_a))
        stem = torch.cat((up_a, up_b, down_a, down_b), dim=1)
        return self.dfus_blocks(stem)

class HSCNN_Plus(nn.Module):
    """HSCNN+ network: the ddfn dense-fusion trunk plus a 1x1 output projection.

    Args:
        in_channels: channels of the input image (3 for the RGB data used here).
        out_channels: channels produced by the final 1x1 conv. The default is 31,
            presumably the spectral bands of the MST++ checkpoint loaded in this
            notebook.
        num_blocks: number of dense fusion blocks in the trunk.
    """

    def __init__(self, in_channels=3, out_channels=31, num_blocks=30):
        super().__init__()
        self.ddfn = ddfn(dim=in_channels, num_blocks=num_blocks)
        trunk_channels = 128 + 32 * num_blocks  # channel count emitted by ddfn
        self.conv_out = nn.Conv2d(trunk_channels, out_channels, kernel_size=1, stride=1, padding=0, bias=False)

    def forward(self, x):
        """Map x of shape [b, in_channels, h, w] to [b, out_channels, h, w]."""
        return self.conv_out(self.ddfn(x))

In [None]:
#| eval: false

# def get_dls(size, bs):
#     path = URLs.IMAGENETTE_160
#     source = untar_data(path)
#     blocks=(ImageBlock, CategoryBlock)
#     tfms = [RandomResizedCrop(size, min_scale=0.35), FlipItem(0.5)]
#     batch_tfms = [Normalize.from_stats(*imagenet_stats)]

#     csv_file = 'noisy_imagenette.csv'
#     inp = pd.read_csv(source/csv_file)
#     dblock = DataBlock(blocks=blocks,
#                splitter=ColSplitter(),
#                get_x=ColReader('path', pref=source),
#                get_y=ColReader(f'noisy_labels_0'),
#                item_tfms=tfms,
#                batch_tfms=batch_tfms)

#     return dblock.dataloaders(inp, path=source, bs=bs)

In [None]:
#| eval: false
# size, bs = 128, 32
# dls = get_dls(size, bs)

In [None]:
#| eval: false
import os

# Root of the MST++ checkout. Override it with the HSI_ROOT environment variable
# instead of editing the absolute paths; the default matches the original setup.
HSI_ROOT = os.environ.get('HSI_ROOT', '/root/Ninjalabo/HSI/MST-plus-plus/MST-plus-plus')
model_path = f'{HSI_ROOT}/test_develop_code/model_zoo/hscnn_plus.pth'  # pretrained HSCNN+ weights
data_root = f'{HSI_ROOT}/dataset/'

In [None]:
#| eval: false
# model_path = Path('/root/Ninjalabo/HSI/MST-plus-plus/MST-plus-plus/test_challenge_code/model_zoo/hscnn_plus.pth')

# path = '/root/Ninjalabo/HSI/MST-plus-plus/MST-plus-plus/dataset/Train_RGB/'

In [None]:
#| eval: false
from fastai.vision.all import *
from pathlib import Path
import torch

# Dataset root: reuse `data_root` instead of repeating the absolute path.
path = Path(data_root)
val_path = path / 'Test_RGB'  # RGB test split; adjust to your folder structure


In [None]:
#| eval: false
# from fastai.vision.all import *

# Identity (autoencoder-style) task: every RGB image is both the input and the target.
# Reuse `data_root` rather than repeating the absolute path.
path = Path(data_root)

data_block = DataBlock(
    # ImageBlock already opens each item with PILImage.create, so no custom
    # get_x/get_y is needed. Dropping the lambdas also keeps the DataBlock
    # picklable; lambdas cannot be pickled, which breaks `learn.export()`.
    blocks=(ImageBlock, ImageBlock),
    get_items=get_image_files,  # recursively collects every image file under `path`
    splitter=RandomSplitter(valid_pct=0.2, seed=42),  # fixed seed -> reproducible train/valid split
    item_tfms=Resize(64),  # small resolution to keep training cheap
)

# Batch size 1 keeps memory low.
# NOTE(review): saved outputs elsewhere show batches of 5, so they appear to come
# from a run with a different bs.
dls = data_block.dataloaders(path, bs=1)


In [None]:
# #| eval: false
# dls = data_block.dataloaders(val_path, bs=5)  # Use the appropriate batch size
# dls.show_batch()

In [None]:
#| eval: false
# Sanity-check one training batch: in this image-to-image setup the input
# and target tensors should have matching shapes.
x, y = dls.one_batch()
for label, batch in (("Input (x) shape:", x), ("Target (y) shape:", y)):
    print(label, batch.shape)


Input (x) shape: torch.Size([5, 3, 64, 64])
Target (y) shape: torch.Size([5, 3, 64, 64])


In [None]:
# # | eval: false
# model = HSCNN_Plus()
# checkpoint = torch.load(model_path)
# if 'state_dict' in checkpoint:
#     model.load_state_dict(checkpoint['state_dict'])
# else:
#     model.load_state_dict(checkpoint)
# model.eval()


In [None]:
# print(model)
# print(torch.load(model_path).keys())


In [None]:
#| eval: false
# Pretrained HSCNN+ with its default configuration (3 -> 31 channels, 30 blocks).
model = HSCNN_Plus()

# The model lives on the CPU, so load the tensors straight to the CPU. Mapping them
# to CUDA first only to copy them back into CPU parameters wastes GPU memory.
# NOTE(review): torch.load unpickles arbitrary objects; only load trusted checkpoints.
checkpoint = torch.load(model_path, map_location='cpu')
state_dict = checkpoint['state_dict'] if 'state_dict' in checkpoint else checkpoint

# Checkpoints saved from nn.DataParallel prefix every key with "module."; strip it
# so the keys match this bare model. Only rebuild the mapping when needed, so the
# state_dict's own metadata is kept in the common case.
prefix = 'module.'
if any(key.startswith(prefix) for key in state_dict):
    state_dict = {(key[len(prefix):] if key.startswith(prefix) else key): value
                  for key, value in state_dict.items()}
model.load_state_dict(state_dict)

model.eval()  # inference mode; as the last expression it also displays the architecture


HSCNN_Plus(
  (ddfn): ddfn(
    (conv_up1): Conv2d(3, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (conv_up2): Conv2d(32, 32, kernel_size=(1, 1), stride=(1, 1), bias=False)
    (conv_down1): Conv2d(3, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (conv_down2): Conv2d(32, 32, kernel_size=(1, 1), stride=(1, 1), bias=False)
    (dfus_blocks): Sequential(
      (0): dfus_block(
        (conv1): Conv2d(128, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (conv_up1): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (conv_up2): Conv2d(32, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (conv_down1): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (conv_down2): Conv2d(32, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (conv_fution): Conv2d(96, 32, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (relu): ReLU(inplace=True)
     

In [None]:
#| eval: false
from torch.nn import MSELoss

# Wrap the pretrained model in a Learner with a pixel-wise MSE loss.
# NOTE(review): `model` here is HSCNN_Plus() with its default 31 output channels,
# while the DataBlock targets are 3-channel RGB images. Fitting this Learner would
# therefore fail in MSELoss. The next cell replaces both `model` and `learn`.
learn = Learner(dls=dls, model=model, loss_func=MSELoss())


In [None]:
#| eval: false
from torch.nn import MSELoss
# Baseline: a small HSCNN+ (RGB -> RGB) trained from scratch.
# num_blocks=5 instead of the default 30 keeps it light (~0.5M parameters).
model = HSCNN_Plus(in_channels=3, out_channels=3, num_blocks=5).to('cpu')
learn = Learner(dls, model, loss_func=MSELoss())
learn.fit_one_cycle(n_epoch=5, lr_max=1e-4)
# learn.fit_one_cycle(4, 1e-4)
# Run inference on validation set
# preds, targs = learn.get_preds(dl=dls.valid)  # Get predictions


epoch,train_loss,valid_loss,time
0,0.047585,0.043253,00:00
1,0.045987,0.039191,00:00
2,0.043653,0.035421,00:00
3,0.041533,0.033387,00:00
4,0.039822,0.032988,00:00


In [None]:
#| eval: false
# from fastai.callback.all import GradientAccumulation

# # Set gradient accumulation steps to effectively multiply your batch size by this factor
# accumulation_steps = 8  # Adjust based on your needs and memory constraints

# # Create the Learner with gradient accumulation and mixed precision
# learn = Learner(dls, model, loss_func=MSELoss(), cbs=[GradientAccumulation(n_acc=accumulation_steps)]).to_fp16()
# learn.fit_one_cycle(5, lr_max=1e-4)

In [None]:
#| eval: false
# files = get_image_files(path)

# def label_func(f): return f[0].isupper()

# dls = ImageDataLoaders.from_name_func(path, files, label_func, item_tfms=Resize(128),bs=32)



In [None]:
#| eval: false
import torch
# Release cached, unused CUDA blocks back to the driver (nothing to do if CUDA was never used).
if torch.cuda.is_initialized():
    torch.cuda.empty_cache()


In [None]:
#| eval: false
import os
# The CUDA caching allocator reads PYTORCH_CUDA_ALLOC_CONF when it is initialised, so
# this only takes effect if it is set before any CUDA work in this kernel (e.g. in the
# first cell, or exported before launching Jupyter).
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
if torch.cuda.is_initialized():
    # print rather than warnings.warn: warnings are globally ignored in this notebook.
    print("PYTORCH_CUDA_ALLOC_CONF was set after CUDA was initialised; restart the "
          "kernel and set it before any CUDA work for it to take effect.")


In [None]:
#| eval: false
from torch.nn import MSELoss  # Use Mean Squared Error as loss function for image-to-image tasks

# Fresh Learner around the already-trained baseline (same model object, pixel-wise MSE loss).
learn = Learner(dls=dls, model=model, loss_func=MSELoss())


In [None]:
#| eval: false
# learn = Learner(dls, model, metrics=[accuracy])

In [None]:
#| eval: false
# Footprint of the float baseline: serialized size in MB and raw parameter count.
# The count uses a thousands separator, like the other size reports in this notebook.
num_parameters = get_num_parameters(learn.model)
disk_size = get_model_size(learn.model)
print(f"Model Size: {disk_size / 1e6:.2f} MB (disk), {num_parameters:,} parameters")

Model Size: 2.08 MB (disk), 516640 parameters


In [None]:
#| eval: false
# Put the baseline in inference mode on the CPU for benchmarking.
# `model` and `learn.model` remain the same object.
model = learn.model
model.eval()
model.to('cpu')
x, y = dls.one_batch()

In [None]:
#| eval: false
# CPU latency of the float baseline on a single image (batch of one).
single_image = x[0][None]
print(f'Inference Speed: {evaluate_cpu_speed(learn.model, single_image)[0]:.2f}ms')

Inference Speed: 15.21ms


In [None]:
#| eval: false
# Fresh batch for the sections below; its shapes are printed as a quick check.
x, y = dls.one_batch()
for label, batch in (("Input Shape:", x), ("Target Shape:", y)):
    print(label, batch.shape)


Input Shape: torch.Size([5, 3, 64, 64])
Target Shape: torch.Size([5, 3, 64, 64])


---

<br>

## **Knowledge Distillation**

<br>

<blockquote>
<pre><b><i> KnowledgeDistillationCallback(teacher.model, loss) </i></b></pre>
<p style="font-size: 15px"><i>
You only need to pass your teacher model (<code>teacher.model</code>) and a distillation loss to the callback. Behind the scenes, fasterai takes care of training your student model with knowledge distillation.
</i></p>
</blockquote>

<br>

In [None]:
#| eval: false
from fasterai.distill.all import *

In [None]:
#| eval: false
import torch

# Free cached CUDA memory before distillation (no-op if CUDA was never used).
if torch.cuda.is_initialized():
    torch.cuda.empty_cache()


In [None]:
#| eval: false
import os
# The CUDA caching allocator reads PYTORCH_CUDA_ALLOC_CONF when it is initialised, so
# this only takes effect if it is set before any CUDA work in this kernel (e.g. in the
# first cell, or exported before launching Jupyter).
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
if torch.cuda.is_initialized():
    # print rather than warnings.warn: warnings are globally ignored in this notebook.
    print("PYTORCH_CUDA_ALLOC_CONF was set after CUDA was initialised; restart the "
          "kernel and set it before any CUDA work for it to take effect.")

In [None]:
#| eval: false
# sum(p.numel() for p in model.parameters() if p.requires_grad) / 1e6  # Total trainable parameters in millions


In [None]:
#| eval: false
# import torch

# print(f"Allocated memory: {torch.cuda.memory_allocated() / 1024 ** 2:.2f} MB")
# print(f"Cached memory: {torch.cuda.memory_reserved() / 1024 ** 2:.2f} MB")


In [None]:
#| eval: false
# !nvidia-smi


In [None]:
#| eval: false
# !kill -9 58089      


In [None]:
#| eval: false
from torch.nn import MSELoss
# Teacher for distillation: a fresh 5-block HSCNN+ (RGB -> RGB) trained from scratch.
# NOTE(review): this teacher has the same configuration as the baseline (~0.5M
# parameters) and is far smaller than the ResNet-18 U-Net student below (~31M).
# Distillation normally transfers from a larger model to a smaller one.
model = HSCNN_Plus(in_channels=3, out_channels=3, num_blocks=5).to('cpu')
teacher = Learner(dls, model, loss_func=MSELoss())
teacher.fit_one_cycle(n_epoch=10, lr_max=1e-4)

epoch,train_loss,valid_loss,time
0,0.042094,0.039655,00:00
1,0.041489,0.037667,00:00
2,0.040237,0.034148,00:00
3,0.038269,0.029985,00:00
4,0.035919,0.025685,00:00
5,0.033408,0.021808,00:00
6,0.030831,0.018923,00:00
7,0.028445,0.017224,00:00
8,0.026424,0.01653,00:00
9,0.024772,0.016417,00:00


In [None]:

#| eval: false


from fastai.vision.all import *
from fastai.callback.all import *
from fastai.vision.models.unet import DynamicUnet
from torchvision.models import resnet18

import torch.nn.functional as F

# Step 1: Student = U-Net whose encoder is the convolutional trunk of an
# ImageNet-pretrained ResNet-18 ([:-2] drops the global pool and the fc classifier).
encoder = nn.Sequential(*list(resnet18(pretrained=True).children())[:-2])
student_model = DynamicUnet(encoder, n_out=3, img_size=(64, 64))  # 3 output channels to match the RGB targets

# Step 2: Learner for the student; pixel-wise MSE against the target image is the task loss.
student = Learner(dls, student_model, loss_func=MSELoss())


# Step 3: Distillation term.
# SoftTarget is a classification-style distillation loss (temperature softmax + KL
# divergence over class logits; check the installed fasterai version). This model emits
# a 3-channel per-pixel regression output, so SoftTarget does not measure how close the
# student's image is to the teacher's. Match the teacher's output pixel-wise with MSE instead.
def output_mse_distillation(pred, teacher_pred, *args, **kwargs):
    """Regression distillation loss: mean squared error between student and teacher outputs.

    Args:
        pred: student output.
        teacher_pred: teacher output for the same batch (same shape as `pred`).
        *args, **kwargs: any extra arguments the callback passes (e.g. feature maps);
            accepted and ignored so the signature mirrors fasterai's `(pred, teacher_pred, ...)` losses.

    Returns:
        Scalar tensor.
    """
    return F.mse_loss(pred, teacher_pred)


# `teacher` is the HSCNN+ Learner trained in an earlier cell.
kd_cb = KnowledgeDistillationCallback(teacher.model, output_mse_distillation)

# Step 4: Train the student with the distillation callback active.
student.fit_one_cycle(10, 1e-4, cbs=kd_cb)


epoch,train_loss,valid_loss,time
0,206.853897,4.70828,00:01
1,117.306244,4.507239,00:01
2,78.637085,3.295578,00:01
3,57.473248,2.667634,00:01
4,44.227245,3.118005,00:01
5,35.264557,3.569392,00:01
6,28.856447,3.880894,00:01
7,24.112537,4.008141,00:01
8,20.488949,4.021497,00:01
9,17.672688,4.249494,00:01


In [None]:
#| eval: false
# Footprint of the distilled student: serialized size in MB and raw parameter count.
# The count uses a thousands separator, like the other size reports in this notebook.
num_parameters = get_num_parameters(student.model)
disk_size = get_model_size(student.model)
print(f"Model Size: {disk_size / 1e6:.2f} MB (disk), {num_parameters:,} parameters")

Model Size: 124.56 MB (disk), 31113108 parameters


---

<br>

## Quantization

In [None]:
#| eval: false
from fasterai.quantize.quantize_callback import *

In [None]:
#| eval: false
# Quantization-aware fine-tuning of the teacher at a small learning rate. QuantizeCallback
# is expected to leave `teacher.model` quantized (see the size report below).
teacher.fit_one_cycle(n_epoch=5, lr_max=1e-5, cbs=QuantizeCallback())

epoch,train_loss,valid_loss,time
0,0.018622,0.017505,00:01
1,0.01887,0.017797,00:00
2,0.019117,0.018241,00:00
3,0.0194,0.018464,00:00
4,0.019566,0.01861,00:00


In [None]:
#| eval: false
# CPU latency of the quantized teacher on a single image (batch of one).
single_image = x[0][None]
print(f'Inference Speed: {evaluate_cpu_speed(teacher.model, single_image)[0]:.2f}ms')

Inference Speed: 11.60ms


In [None]:
#| eval: false
def count_parameters_quantized(model):
    """Count weight and bias elements of every Conv2d / Linear layer in `model`.

    Works for float, partially quantized and fully quantized models:

    - Quantized modules expose their weight and bias through *methods*
      (`module.weight()`, `module.bias()`).
    - Float modules store them as attributes. Calling `weight()` on a float
      layer raises TypeError, so the two cases are handled separately.

    Layers of other types (e.g. normalization) are not counted.

    Args:
        model: a torch.nn.Module, possibly after quantization.

    Returns:
        int: total number of weight and bias elements.
    """
    import torch.ao.nn.quantized as nnq  # make sure the quantized module namespace is loaded

    quantized_types = (nnq.Conv2d, nnq.Linear)
    float_types = (torch.nn.Conv2d, torch.nn.Linear)

    total_params = 0
    for module in model.modules():
        if isinstance(module, quantized_types):
            weight, bias = module.weight(), module.bias()  # accessor methods on quantized layers
        elif isinstance(module, float_types):
            weight, bias = module.weight, module.bias  # plain tensors / Parameters
        else:
            continue
        total_params += weight.numel()
        if bias is not None:
            total_params += bias.numel()
    return total_params

In [None]:
#| eval: false
# Footprint after quantization-aware training. Quantized layers do not keep their
# weights as ordinary nn.Parameters, so they are counted with count_parameters_quantized.
disk_size = get_model_size(teacher.model)
num_parameters = count_parameters_quantized(teacher.model)
print(f"Model Size: {disk_size / 1e6:.2f} MB (disk), {num_parameters:,} parameters")

Model Size: 0.59 MB (disk), 514,976 parameters
