In [1]:
# %pip (rather than !pip) installs into the environment of the running kernel.
# The version is pinned to the one this notebook was last run with.
%pip install guided-filter-pytorch==3.7.5

Collecting guided-filter-pytorch
  Downloading guided_filter_pytorch-3.7.5-py3-none-any.whl.metadata (1.6 kB)
Downloading guided_filter_pytorch-3.7.5-py3-none-any.whl (3.8 kB)
Installing collected packages: guided-filter-pytorch
Successfully installed guided-filter-pytorch-3.7.5


In [2]:
import torch
import torch.optim as optim
import torch.nn.functional as F
import torch.cuda as cuda
from torchvision import transforms
from torchvision.datasets import ImageFolder
from torch import nn
from torchvision.models.alexnet import AlexNet_Weights
import torchvision
import numpy as np
import cv2
from guided_filter_pytorch.guided_filter import GuidedFilter
import time
from tqdm import tqdm

**Creating the functions for extracting the high- and low-frequency image components**

In [3]:
def createLowFrequencyComponent(img, guided_filter_Radius = 10, eps = 0.01):
    """Return the guided-filter-smoothed ("low frequency") version of an image file.

    The image is read with OpenCV, scaled to [0, 1], and filtered with a
    ``GuidedFilter`` that uses a single-channel grayscale copy of the same
    image as the guide.

    Args:
        img: Path of the image file (anything ``cv2.imread`` accepts).
        guided_filter_Radius: Radius ``r`` handed to ``GuidedFilter``.
        eps: Regularisation ``eps`` handed to ``GuidedFilter``.

    Returns:
        numpy array laid out as (height, width, channels) holding the
        filtered image.

    Raises:
        FileNotFoundError: If ``cv2.imread`` could not read ``img``.
    """
    image = cv2.imread(img)
    if image is None:
        # cv2.imread reports failure by returning None instead of raising;
        # fail here with a clear message rather than deep inside cvtColor.
        raise FileNotFoundError(f"cv2.imread could not read image: {img!r}")

    # NOTE(review): cv2.imread returns channels in BGR order, but the RGB->GRAY
    # conversion code is used, so the red/blue luma weights are swapped.
    # Left unchanged on purpose: the checkpoints loaded later in this notebook
    # were presumably trained with this exact preprocessing - confirm before
    # switching to cv2.COLOR_BGR2GRAY.
    grayscale_image = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)

    # (H, W, C) uint8 -> (1, C, H, W) float in [0, 1]
    img_tensor = torch.from_numpy(image).float().permute(2, 0, 1).unsqueeze(0) / 255.0
    # (H, W) uint8 -> (1, 1, H, W) float in [0, 1]
    gray_tensor = torch.from_numpy(grayscale_image).float().unsqueeze(0).unsqueeze(0) / 255.0

    GF = GuidedFilter(r=guided_filter_Radius, eps=eps)

    # First argument is the guide (grayscale), second the image being filtered.
    low_freq_image = GF(gray_tensor, img_tensor)
    low_freq_image = low_freq_image.squeeze(0).permute(1, 2, 0)    ## back to (H, W, C)
    low_freq_image = low_freq_image.numpy()     ## convert tensor to numpy array

    return low_freq_image

def createHighFrequencyComponent(img, epsilon=0.01):
    """Return the single-channel "high frequency" (detail) layer of an image file.

    The detail layer is the element-wise ratio between the image and its
    low-frequency version from ``createLowFrequencyComponent``. Only the first
    (Y) channel of the YUV-converted ratio is kept, min-max scaled to [0, 1].

    Args:
        img: Path of the image file (anything ``cv2.imread`` accepts).
        epsilon: Small constant added to the denominator for numerical stability.

    Returns:
        2-D numpy array (height, width) with values in [0, 1]. An all-zero
        array is returned when the Y channel is constant, instead of the NaNs
        a 0/0 division would produce.

    Raises:
        FileNotFoundError: If ``cv2.imread`` could not read ``img``.
    """
    image = cv2.imread(img)
    if image is None:
        # cv2.imread reports failure by returning None instead of raising.
        raise FileNotFoundError(f"cv2.imread could not read image: {img!r}")

    # create the low frequency image (numpy, same height/width/channels as `image`)
    low_freq_image = createLowFrequencyComponent(img)

    # create the high frequency image
    # A scalar epsilon broadcasts over any image size, so no resolution is
    # hard-coded here. Everything stays in float32 numpy, which cvtColor accepts.
    ratio = image.astype(np.float32) / (low_freq_image + np.float32(epsilon))
    ratio = np.ascontiguousarray(ratio, dtype=np.float32)

    # NOTE(review): the data is in OpenCV's BGR order but the RGB->YUV code is
    # used; kept as-is to stay consistent with the existing checkpoints - confirm.
    Ih_yuv = cv2.cvtColor(ratio, cv2.COLOR_RGB2YUV)
    Y = Ih_yuv[:, :, 0]

    y_min = Y.min()
    y_range = Y.max() - y_min
    if y_range == 0:
        # Constant channel: min-max scaling is undefined, return zeros.
        return np.zeros_like(Y)

    high_frequency_image = (Y - y_min) / y_range

    return high_frequency_image

**Creating the first stream**

In [4]:
class FirstStream(nn.Module):
    """First stream: ImageNet-pretrained AlexNet trunk followed by two new linear layers.

    Input is a batch of 3-channel images; output is a (batch, 531) tensor.

    NOTE on attribute names: torchvision's AlexNet ``classifier`` is ordered
    Dropout, Linear(9216, 4096), ReLU, Dropout, Linear(4096, 4096), ReLU,
    Linear. Indices 0..5 therefore map to attributes as follows:

        fc6      -> Dropout               relu6 -> Linear(9216, 4096)
        dropout6 -> ReLU                  fc7   -> Dropout
        relu7    -> Linear(4096, 4096)    dropout7 -> ReLU

    The attribute names are off by one relative to what they hold, but the
    resulting order (Dropout, Linear, ReLU, twice) is AlexNet's own. The names
    are kept because they are the ``state_dict`` keys of saved checkpoints.
    """

    def __init__(self):
        super(FirstStream, self).__init__()

        # Load pretrained AlexNet. `weights=` is the supported keyword;
        # `pretrained=` is deprecated in torchvision.
        alexnet = torchvision.models.alexnet(weights=AlexNet_Weights.DEFAULT)

        # Use AlexNet features (conv1 to conv5)
        self.features = alexnet.features  # Conv layers

        # First six classifier modules (see class docstring for what each holds).
        self.fc6 = alexnet.classifier[0]       # actually Dropout
        self.relu6 = alexnet.classifier[1]     # actually Linear(9216, 4096)
        self.dropout6 = alexnet.classifier[2]  # actually ReLU

        self.fc7 = alexnet.classifier[3]       # actually Dropout
        self.relu7 = alexnet.classifier[4]     # actually Linear(4096, 4096)
        self.dropout7 = alexnet.classifier[5]  # actually ReLU

        # Custom fc8 and fc9 layers
        self.fc8 = nn.Linear(4096, 2048)
        self.relu8 = nn.ReLU()
        self.dropout8 = nn.Dropout(p=0.5)

        self.fc9 = nn.Linear(2048, 531)

    def forward(self, x):
        """Run the stream; returns the raw ``fc9`` output of shape (batch, 531)."""
        x = self.features(x)              # conv1-conv5
        # AlexNet's adaptive avgpool is skipped, so the flattened size only
        # matches Linear(9216, ...) when the conv output is 256 x 6 x 6.
        x = torch.flatten(x, 1)           # Flatten to (B, 9216)

        x = self.fc6(x)
        x = self.relu6(x)
        x = self.dropout6(x)

        x = self.fc7(x)
        x = self.relu7(x)
        x = self.dropout7(x)

        x = self.fc8(x)
        x = self.relu8(x)
        x = self.dropout8(x)

        x = self.fc9(x)

        return x

In [5]:
# Build the AlexNet used by the second stream: its first conv layer is collapsed
# from 3 input channels to 1 by taking a luma-weighted sum of the per-channel kernels.
alex_mod = torchvision.models.alexnet(weights = AlexNet_Weights.DEFAULT)
conv_1 = alex_mod.features[0]

# get the weights of the 1st conv layer: (filters, colour channels, kernel_h, kernel_w)
weights = conv_1.weight
num_filters, num_color_channels, kernel_h, kernel_w = weights.shape

# Luma coefficients for R, G, B (ITU-R BT.601: 0.2989, 0.5870, 0.1140).
# The green weight was previously mistyped as 0.578.
luma_components = torch.tensor([[0.2989, 0.5870, 0.1140]])

# Pure weight surgery - no autograd graph is needed.
with torch.no_grad():
    # flatten each kernel: (filters, channels, kernel_h * kernel_w)
    before_luma_weights = weights.reshape(num_filters, num_color_channels, kernel_h * kernel_w)

    # compute the luma weights in one batched matmul:
    # (filters, k*k, channels) @ (channels, 1) -> (filters, k*k, 1)
    luma_weights = before_luma_weights.transpose(1, 2) @ luma_components.T

# set the new single-channel luma weights on the first conv layer
conv_1.weight = torch.nn.Parameter(luma_weights.reshape(num_filters, 1, kernel_h, kernel_w))
# keep the layer's metadata / repr consistent with its new 1-channel weight
conv_1.in_channels = 1

Downloading: "https://download.pytorch.org/models/alexnet-owt-7be5be79.pth" to /root/.cache/torch/hub/checkpoints/alexnet-owt-7be5be79.pth
100%|██████████| 233M/233M [00:01<00:00, 196MB/s] 


**Creating the second stream**

In [6]:
class SecondStream(nn.Module):
    """Second stream, built on the module-level ``alex_mod`` network.

    ``alex_mod`` is the AlexNet whose first conv layer was rewritten to take a
    single input channel. Its layers are referenced here, not copied, so this
    stream shares parameters with ``alex_mod``. Three new linear layers
    (fc8, fc9, fc10) follow; the output is a (batch, 531) tensor.

    NOTE on attribute names: torchvision's AlexNet ``classifier`` is ordered
    Dropout, Linear, ReLU, Dropout, Linear, ReLU, Linear. So ``fc6``/``fc7``
    hold Dropout modules, ``relu6``/``relu7`` hold the Linear layers and
    ``dropout6``/``dropout7`` hold the ReLUs. The names are state_dict keys and
    must not be renamed.
    """

    def __init__(self):
        super().__init__()

        backbone = alex_mod
        head = backbone.classifier

        # convolutional trunk
        self.features = backbone.features

        # first six classifier modules, registered in their original order
        self.fc6, self.relu6, self.dropout6 = head[0], head[1], head[2]
        self.fc7, self.relu7, self.dropout7 = head[3], head[4], head[5]

        # freshly initialised layers on top
        self.fc8, self.relu8, self.dropout8 = nn.Linear(4096, 2048), nn.ReLU(), nn.Dropout(p=0.5)
        self.fc9, self.relu9, self.dropout9 = nn.Linear(2048, 2048), nn.ReLU(), nn.Dropout(p=0.5)

        self.fc10 = nn.Linear(2048, 531)

    def forward(self, x):
        """Apply the conv trunk, flatten, then every head module in sequence."""
        # Flatten conv output to (B, features) before the head
        x = torch.flatten(self.features(x), 1)

        head_modules = (
            self.fc6, self.relu6, self.dropout6,
            self.fc7, self.relu7, self.dropout7,
            self.fc8, self.relu8, self.dropout8,
            self.fc9, self.relu9, self.dropout9,
            self.fc10,
        )
        for module in head_modules:
            x = module(x)

        return x

**Defining both the stream models**

In [7]:
device = torch.device("cuda" if cuda.is_available() else "cpu")

stream_1 = FirstStream()
stream_2 = SecondStream()

# load the state dictionaries
# map_location="cpu": the streams still live on the CPU at this point, and
# without it torch.load fails on a CPU-only machine when the checkpoint was
# saved from CUDA tensors.
# NOTE(review): torch.load unpickles the file - only load checkpoints from a
# trusted source (or pass weights_only=True where the installed torch supports it).
# strict=False tolerates key mismatches; the recorded output of this cell shows
# `modfc.weight` / `modfc.bias` reported as unexpected keys for the last load.
stream_1.load_state_dict(torch.load('/kaggle/input/stream-models/stream1_model.pth', map_location="cpu"), strict=False)
stream_2.load_state_dict(torch.load('/kaggle/input/stream-models/stream2_model.pth', map_location="cpu"), strict=False)



_IncompatibleKeys(missing_keys=[], unexpected_keys=['modfc.weight', 'modfc.bias'])

**Creating the dataloader for the joint Stream**

In [8]:
# Number of samples per batch produced by the DataLoader built in this cell.
BATCH_SIZE = 32
# Root folder handed to the ImageFolder-based dataset built in this cell.
data_root = '/kaggle/input/11k-hands-training-dataset/content/drive/MyDrive/train_images/train'

# override the ImageFolder to include the custom function
class CustomImageFolder(ImageFolder):
    """ImageFolder that yields ``(blurred_img, detailed_img, target)`` per sample.

    ``blurred_img`` is the min-max scaled low-frequency image (channels first).
    ``detailed_img`` is the single-channel high-frequency image resized to
    224 x 224. Both are float tensors, and ``transform`` (if given) is applied
    to each of them.
    """

    def __init__(self, root, transform=None):
        # The parent gets no transform: samples are built from the file path in
        # __getitem__, and the transform is applied there to both tensors.
        super().__init__(root=root, transform=None)
        self.base_transform = transform

    @staticmethod
    def _min_max_scale(arr):
        """Scale ``arr`` to [0, 1]; a constant array maps to zeros instead of NaN (0/0)."""
        lowest = arr.min()
        value_range = arr.max() - lowest
        if value_range == 0:
            return np.zeros_like(arr)
        return (arr - lowest) / value_range

    def __getitem__(self, index):
        """Build the two stream inputs for the sample at ``index``."""
        path, target = self.samples[index]

        # Both helper functions take the image *path* and read the file themselves.
        blurred_img = createLowFrequencyComponent(path)
        blurred_img = self._min_max_scale(blurred_img)
        blurred_img = torch.from_numpy(blurred_img).permute(2, 0, 1).float()  # (H, W, C) -> (C, H, W)

        detailed_img = createHighFrequencyComponent(path)
        detailed_img = cv2.resize(detailed_img, (224, 224))
        detailed_img = np.expand_dims(detailed_img, axis=0)  # shape: (1, 224, 224)
        detailed_img = torch.from_numpy(detailed_img).float()

        if self.base_transform is not None:
            blurred_img = self.base_transform(blurred_img)
            detailed_img = self.base_transform(detailed_img)

        return blurred_img, detailed_img, target

# Only a resize is applied; each sample is already a tensor when it reaches this pipeline.
base_transform = transforms.Compose([transforms.Resize((224, 224))])

# Dataset yielding (blurred image, detail image, label) and a shuffled loader over it.
dataset = CustomImageFolder(data_root, transform=base_transform)
dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
)

**Creating the joint model: TwoStreamNet**

In [9]:
class TwoStreamNet(nn.Module):
    """Joint model that fuses the outputs of two stream networks into 2 class probabilities.

    Args:
        FirstStream: module *instance* applied to the blurred image (despite
            the class-like parameter name).
        SecondStream: module *instance* applied to the detailed image.

    The two stream outputs are concatenated along dim 1 (1062 features
    expected). They pass through Linear(1062, 1062), are average-pooled in
    adjacent pairs down to 531, and a final Linear maps them to 2 values,
    which go through Softmax.

    NOTE: ``forward`` returns probabilities, not logits. ``nn.CrossEntropyLoss``
    applies log-softmax itself, so feeding it this output normalises twice.
    """

    def __init__(self, FirstStream, SecondStream):
        super().__init__()

        self.stream1 = FirstStream
        self.stream2 = SecondStream

        # Layer positions matter: they are the indices in the state_dict keys.
        fusion_layers = [
            nn.Linear(in_features=1062, out_features=1062),
            nn.Unflatten(1, (1, 1062)),             # (B, 1062) -> (B, 1, 1062) for 1-D pooling
            nn.AvgPool1d(kernel_size=2, stride=2),  # (B, 1, 1062) -> (B, 1, 531)
            nn.Flatten(),                           # (B, 1, 531) -> (B, 531)
            nn.Linear(in_features=531, out_features=2),
            nn.Softmax(dim=1),
        ]
        self.sequential = nn.Sequential(*fusion_layers)

    def forward(self, blurred_img, detailed_img):
        """Return a (batch, 2) tensor of class probabilities."""
        stream_outputs = [self.stream1(blurred_img), self.stream2(detailed_img)]
        fused = torch.cat(stream_outputs, dim=1)
        return self.sequential(fused)

**Training the joint model**

In [None]:
# Fine-tune the joint two-stream model, starting from a saved checkpoint.
model = TwoStreamNet(stream_1, stream_2).to(device)
# map_location=device so the checkpoint loads whether or not CUDA is available.
# NOTE(review): torch.load unpickles the file - only load trusted checkpoints.
model.load_state_dict(torch.load('/kaggle/input/stream-models/joint_model.pth', map_location=device))

# TwoStreamNet ends in Softmax, so `outputs` are probabilities. CrossEntropyLoss
# expects raw logits and would apply softmax a second time, flattening the
# gradients. NLLLoss on log-probabilities is the equivalent loss for a model
# that already outputs probabilities.
criterion = nn.NLLLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.002, momentum=0.9)
num_epochs = 10

for epoch in range(num_epochs):
    model.train()
    total_loss = 0.0   # sum of per-sample losses over the epoch
    correct = 0        # correctly classified samples so far
    total = 0          # samples seen so far
    loop = tqdm(dataloader, total=len(dataloader), desc=f"Epoch {epoch+1}/{num_epochs}")

    for (blurred_img, detailed_img, labels) in loop:
        blurred_img, detailed_img, labels = blurred_img.to(device), detailed_img.to(device), labels.to(device)

        outputs = model(blurred_img, detailed_img)
        # clamp before log so a probability of exactly 0 cannot produce -inf
        loss = criterion(torch.log(outputs.clamp_min(1e-12)), labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # loss is a batch mean; weight it by batch size to get a per-sample total
        total_loss += loss.item() * blurred_img.size(0)
        _, predicted = torch.max(outputs, 1)
        correct += (predicted == labels).sum().item()
        total += labels.size(0)

        loop.set_postfix(loss=loss.item(), acc=correct/total)

    # max(total, 1) avoids a ZeroDivisionError if the dataloader yielded nothing
    seen = max(total, 1)
    print(f"For Epoch {epoch+1} — Accuracy: {correct/seen:.4f}, Loss: {total_loss/seen:.4f}")
    # checkpoint after every epoch (overwrites the previous one)
    torch.save(model.state_dict(), "/kaggle/working/joint_model.pth")

In [None]:
# Explicit final save of the model's state_dict; this is the same path the
# training loop writes to after every epoch.
torch.save(model.state_dict(), "/kaggle/working/joint_model.pth")