In [None]:

import sys
sys.path.append("./../")
import os
import numpy as np
import random
from torch import einsum
import torch.nn.functional as F
import torch
import torch.nn as nn
import torch.utils.data as dataf
from torch.utils.data import Dataset
from scipy import io
from scipy.io import loadmat as loadmat
from sklearn.decomposition import PCA
from torch.nn.parameter import Parameter
import torchvision.transforms.functional as TF
from torch.nn import LayerNorm,Linear,Dropout,Softmax
import time
from PIL import Image
import math
from operator import truediv
from sklearn.metrics import confusion_matrix, accuracy_score, classification_report, cohen_kappa_score
from torchsummary import summary
import matplotlib.pyplot as plt
import torch.backends.cudnn as cudnn
import re
from pathlib import Path
import copy

import utils
import logger

# Reproducibility: seed every RNG the notebook uses and force deterministic cuDNN kernels.
# cudnn.deterministic alone does not make runs repeatable without fixed seeds.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)  # seeds the RNG on all devices, including CUDA
cudnn.deterministic = True
cudnn.benchmark = False  # autotuning may select different algorithms from run to run
     

In [None]:
#Dataset
class _HSI_LiDAR_SplitDataset(torch.utils.data.Dataset):
    """Paired HSI / LiDAR patches and labels loaded from one pre-cut .mat split.

    Files are read from ``{root}/{dataset}{patch_size}x{patch_size}/`` as
    ``HSI_{split}.mat``, ``LIDAR_{split}.mat`` and ``{split}Label.mat``; each
    file stores its array under the ``'Data'`` key.

    Attributes:
        hs_image: float32 tensor, HSI 'Data' with axes permuted (0, 3, 1, 2)
            (assumes 'Data' is (n, H, W, bands) -- TODO confirm).
        lidar_image: float32 tensor, LiDAR 'Data' with the same permutation.
        lbls: int64 tensor of shape (n,), labels shifted down by one
            (presumably 1-based MATLAB classes -> 0-based).
    """

    def __init__(self, dataset, split, root='.', patch_size=11):
        folder = os.path.join(root, f'{dataset}{patch_size}x{patch_size}')
        hsi = loadmat(os.path.join(folder, f'HSI_{split}.mat'))['Data']
        lidar = loadmat(os.path.join(folder, f'LIDAR_{split}.mat'))['Data']
        labels = loadmat(os.path.join(folder, f'{split}Label.mat'))['Data']

        # Move the last (channel) axis next to the sample axis, as Conv layers expect.
        self.hs_image = torch.from_numpy(hsi.astype(np.float32)).permute(0, 3, 1, 2)
        self.lidar_image = torch.from_numpy(lidar.astype(np.float32)).permute(0, 3, 1, 2)
        # Cast before subtracting: MATLAB labels are often unsigned integers,
        # where "- 1" applied to a 0 would silently wrap instead of going negative.
        self.lbls = (torch.from_numpy(labels.astype(np.int64)) - 1).reshape(-1)

    def __len__(self):
        return self.hs_image.shape[0]

    def __getitem__(self, i):
        return self.hs_image[i], self.lidar_image[i], self.lbls[i]


class HSI_LiDAR_DatasetTrain(_HSI_LiDAR_SplitDataset):
    """Training split (HSI_Tr.mat, LIDAR_Tr.mat, TrLabel.mat).

    Args:
        dataset: dataset name prefix of the data folder (e.g. 'Trento').
        root: directory containing the ``{dataset}{p}x{p}`` folder.
        patch_size: patch side length ``p`` encoded in the folder name.
    """

    def __init__(self, dataset='Trento', root='.', patch_size=11):
        super().__init__(dataset, 'Tr', root=root, patch_size=patch_size)


class HSI_LiDAR_DatasetTest(_HSI_LiDAR_SplitDataset):
    """Test split (HSI_Te.mat, LIDAR_Te.mat, TeLabel.mat).

    Args:
        dataset: dataset name prefix of the data folder (e.g. 'Trento').
        root: directory containing the ``{dataset}{p}x{p}`` folder.
        patch_size: patch side length ``p`` encoded in the folder name.
    """

    def __init__(self, dataset='Trento', root='.', patch_size=11):
        super().__init__(dataset, 'Te', root=root, patch_size=patch_size)
     

In [4]:
from torch import nn
class HetConv(nn.Module):
    """Heterogeneous convolution: a grouped k x k conv plus a grouped 1 x 1 conv, summed.

    Args:
        in_channels: input channels (must be divisible by ``g`` and ``p``).
        out_channels: output channels (must be divisible by ``g`` and ``p``).
        kernel_size: spatial size of the grouped (groupwise) convolution.
        stride: stride used by both branches.
        padding: padding of the k x k branch. ``None`` (default) uses
            ``kernel_size // 2`` so that, for odd kernels, both branches produce
            the same spatial size and can be added.
        bias: whether both convolutions have a bias; ``None`` keeps PyTorch's default (True).
        p: number of groups of the pointwise (1 x 1) branch.
        g: number of groups of the groupwise (k x k) branch.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=None, bias=None, p=64, g=64):
        super(HetConv, self).__init__()
        if padding is None:
            # "Same" padding for odd kernels (kernel_size // 3 was only correct for k=3).
            padding = kernel_size // 2
        use_bias = True if bias is None else bool(bias)
        # Groupwise Convolution
        self.groupwise_conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, groups=g,
                                        padding=padding, stride=stride, bias=use_bias)
        # Pointwise Convolution
        self.pointwise_conv = nn.Conv2d(in_channels, out_channels, kernel_size=1, groups=p,
                                        stride=stride, bias=use_bias)

    def forward(self, x):
        return self.groupwise_conv(x) + self.pointwise_conv(x)

# Cross-HL Attention Module
class CrossHL_attention(nn.Module):
    """Cross-HL attention: queries derived from LiDAR attend over keys derived from HSI tokens.

    Args:
        dim: feature size ``C`` of the HSI tokens (``x`` must have ``C == dim``).
        patches: number of HSI tokens ``N``; ``Wq`` and ``Wv`` both take
            ``patches`` input features.
        num_heads: number of heads ``h``. The query/value reshapes in ``forward``
            only line up when ``dim == num_heads ** 2`` (e.g. dim=64, num_heads=8).
        qkv_bias: add a bias term to the query/key/value projections.
        qk_scale: override for the default ``head_dim ** -0.5`` score scaling.
        attn_drop: dropout probability applied to the attention weights.
        proj_drop: dropout probability applied after the output projection.
    """

    def __init__(self, dim, patches, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0.1, proj_drop=0.1):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim ** -0.5
        self.dim = dim
        self.Wq = nn.Linear(patches, dim * num_heads, bias=qkv_bias)
        self.Wk = nn.Linear(dim, dim, bias=qkv_bias)
        # Applied to x.transpose(1, 2), whose last axis has N == patches entries
        # (previously patches + 1, which could never match the input).
        self.Wv = nn.Linear(patches, dim, bias=qkv_bias)
        self.linear_projection = nn.Linear(dim * num_heads, dim)
        self.linear_projection_drop = nn.Dropout(proj_drop)
        self.attn_dropout = nn.Dropout(attn_drop)

    def forward(self, x, x2):
        """
        Args:
            x: HSI tokens of shape (B, N, C) with N == patches and C == dim.
            x2: LiDAR descriptor with ``patches`` values on its last axis and one
                row per sample, e.g. (B, 1, patches).
        Returns:
            Tensor of shape (B, N, dim).
        """
        B, N, C = x.shape
        h = self.num_heads
        head_dim = self.dim // h
        # Query from the LiDAR vector: (B, h, h, head_dim).
        query = self.Wq(x2).reshape(B, h, h, head_dim)
        # Keys from the HSI tokens: (B, h, N, head_dim).
        key = self.Wk(x).reshape(B, N, h, head_dim).permute(0, 2, 1, 3)
        # Values mix along the token axis: (B, h, head_dim, C).
        value = self.Wv(x.transpose(1, 2)).reshape(B, C, h, head_dim).permute(0, 2, 3, 1)
        attention = torch.einsum('bhid,bhjd->bhij', key, query) * self.scale
        attention = self.attn_dropout(attention.softmax(dim=-1))

        x = torch.einsum('bhij,bhjd->bhid', attention, value)  # (B, h, N, C)
        # Bring the token axis forward before merging heads; reshaping (B, h, N, C)
        # directly to (B, N, h*C) would interleave rows of different tokens/heads.
        x = x.permute(0, 2, 1, 3).reshape(B, N, -1)
        x = self.linear_projection(x)
        x = self.linear_projection_drop(x)
        return x
    
    
class GlobalFilter(nn.Module):
    """Learnable global filter in the 2-D Fourier domain (GFNet-style token mixer).

    Tokens are laid out on an ``a x b`` grid, transformed with ``rfft2``,
    multiplied element-wise by the learned complex filter stored in
    ``complex_weight`` and transformed back with ``irfft2``.

    Args:
        dim: token feature size ``C``.
        h, w: size of the stored frequency grid. When it differs from the
            input's rfft2 grid ``(a, b // 2 + 1)``, the weight is bilinearly
            resized to fit, so the same parameter works for any patch size.
    """

    def __init__(self, dim, h=14, w=8):
        super().__init__()
        # Real and imaginary parts stored in the last axis: (h, w, dim, 2).
        self.complex_weight = nn.Parameter(torch.randn(h, w, dim, 2, dtype=torch.float32) * 0.02)
        self.w = w
        self.h = h

    def _filter(self, rows, cols):
        """Return the learned complex filter as a (rows, cols, dim) tensor."""
        weight = self.complex_weight
        if tuple(weight.shape[:2]) != (rows, cols):
            # (h, w, C, 2) -> (2, C, h, w): resize real and imaginary parts as two images.
            weight = F.interpolate(weight.permute(3, 2, 0, 1), size=(rows, cols),
                                   mode='bilinear', align_corners=True)
            weight = weight.permute(2, 3, 1, 0)
        return torch.complex(weight[..., 0], weight[..., 1])

    def forward(self, x, spatial_size=None):
        """Filter tokens x of shape (B, N, C); returns float32 (B, N, C).

        ``spatial_size`` gives the (a, b) token grid; by default a square grid
        with side ``int(sqrt(N))`` is assumed.
        """
        B, N, C = x.shape
        if spatial_size is None:
            a = b = int(math.sqrt(N))
        else:
            a, b = spatial_size
        x = x.reshape(B, a, b, C).to(torch.float32)
        x = torch.fft.rfft2(x, dim=(1, 2), norm='ortho')  # (B, a, b // 2 + 1, C), complex
        # Use the registered, trainable weight. A fresh random tensor per call was
        # never optimized and made the output nondeterministic.
        x = x * self._filter(a, b // 2 + 1)
        x = torch.fft.irfft2(x, s=(a, b), dim=(1, 2), norm='ortho')  # trimmed / zero-padded to s
        x = x.reshape(B, N, C)
        return x
    
    
class MultiLayerPerceptron(nn.Module):
    """Feed-forward block: Linear -> GELU -> Dropout -> Linear -> Dropout.

    Args:
        dim: size of the input and output feature axis.
        mlp_dim: size of the hidden layer.
    """

    def __init__(self, dim, mlp_dim):
        super().__init__()
        self.fclayer1 = Linear(dim, mlp_dim)  # dim -> mlp_dim
        self.fclayer2 = Linear(mlp_dim, dim)  # mlp_dim -> dim
        self.act_fn = nn.GELU()
        self.dropout = Dropout(0.1)  # one module reused at both dropout sites
        self._init_weights()

    def _init_weights(self):
        """Xavier-uniform weights, then near-zero normal biases (same RNG order as before)."""
        layers = (self.fclayer1, self.fclayer2)
        for layer in layers:
            nn.init.xavier_uniform_(layer.weight)
        for layer in layers:
            nn.init.normal_(layer.bias, std=1e-6)

    def forward(self, x):
        hidden = self.dropout(self.act_fn(self.fclayer1(x)))
        return self.dropout(self.fclayer2(hidden))
  
  
  
  
class SingleEncoderBlock(nn.Module):
    """One Cross-HL encoder layer with a parallel global-filter branch.

    The HSI tokens ``x1`` are layer-normalized and fed to two branches: a
    ``GlobalFilter`` and ``CrossHL_attention`` (which also receives ``x2``).
    The attention branch gets a residual connection to ``x1``. Both branches
    are normalized with ``ffn_norm``, fused by summation and passed through the
    MLP, with a second residual around the MLP.

    Args:
        dim: token feature size.
        num_heads: number of heads passed to CrossHL_attention.
        mlp_dim: hidden size of the MLP.
        patches: number of tokens per sample (default: an 11 x 11 patch).
    """

    def __init__(self, dim, num_heads, mlp_dim, patches=11**2):
        super().__init__()
        self.layer_norm = LayerNorm(dim, eps=1e-6)  # pre-attention norm
        self.ffn_norm = LayerNorm(dim, eps=1e-6)    # pre-MLP norm, applied to both branches
        self.ffn = MultiLayerPerceptron(dim, mlp_dim)
        self.cross_hl_attention = CrossHL_attention(dim=dim, patches=patches, num_heads=num_heads)
        self.Globalfilter = GlobalFilter(dim, h=14, w=8)

    def forward(self, x1, x2):
        x = self.layer_norm(x1)
        x_gf = self.Globalfilter(x)
        x_chl = self.cross_hl_attention(x, x2) + x1  # residual around attention
        res = x_chl  # residual around the MLP is the post-attention output
        # Fuse the branches feature-wise. Concatenating along dim=0 stacked them on
        # the batch axis, doubling the batch and breaking the residual add.
        x = self.ffn_norm(x_chl) + self.ffn_norm(x_gf)
        x = self.ffn(x)
        return x + res
    
    
class Encoder(nn.Module):
    """Stack of ``depth`` SingleEncoderBlock layers followed by a final LayerNorm.

    ``forward(x, x2)`` threads ``x`` through every block (each block also gets
    the same ``x2``), normalizes the result and returns only token 0, i.e.
    ``encoded[:, 0]``.
    """

    def __init__(self, dim, num_heads=8, mlp_dim=512, depth=2):
        super().__init__()
        self.layer = nn.ModuleList()
        self.encoder_norm = LayerNorm(dim, eps=1e-6)
        for _ in range(depth):
            self.layer.append(SingleEncoderBlock(dim, num_heads, mlp_dim))

    def forward(self, x, x2):
        hidden = x
        for block in self.layer:
            hidden = block(hidden, x2)
        return self.encoder_norm(hidden)[:, 0]


class CrossHL_Transformer(nn.Module):
    """HSI + LiDAR patch classifier: spectral Conv3d stem, HetConv, Cross-HL encoder, linear head.

    Args:
        FM: width multiplier; the token feature size is ``FM * 4``.
        NC: number of HSI bands. The depth-9 Conv3d (no depth padding) leaves
            ``NC - 8`` spectral slices with 8 feature maps each.
        NCLidar: number of LiDAR channels (stored; not used by ``forward``).
        Classes: number of output classes.
        patchsize: side length of the square input patches.
    """

    def __init__(self, FM, NC, NCLidar, Classes, patchsize):
        super(CrossHL_Transformer, self).__init__()
        self.patchsize = patchsize
        self.NCLidar = NCLidar
        self.conv5 = nn.Sequential(
            nn.Conv3d(1, 8, (9, 3, 3), padding=(0, 1, 1), stride=1),
            nn.BatchNorm3d(8),
            nn.ReLU()
        )
        self.hetconv_layer = nn.Sequential(
            HetConv(8 * (NC - 8), FM*4, p=1, g=(FM*4)//4 if (8 * (NC - 8))%FM == 0 else (FM*4)//8),
            nn.BatchNorm2d(FM*4),
            nn.ReLU()
        )
        self.ca = Encoder(FM*4)
        self.fclayer = nn.Linear(FM*4, Classes)
        self.position_embeddings = nn.Parameter(torch.randn(1, (patchsize**2), FM*4))
        self.dropout = nn.Dropout(0.1)
        torch.nn.init.xavier_uniform_(self.fclayer.weight)
        torch.nn.init.normal_(self.fclayer.bias, std=1e-6)

    def forward(self, x1, x2):
        """Classify a batch of patches.

        Args:
            x1: HSI patches, reshaped to (B, bands, patchsize, patchsize).
            x2: LiDAR patches, reshaped to (B, channels, patchsize**2); the
                channels are averaged into one (B, 1, patchsize**2) vector.
        Returns:
            Logits of shape (B, Classes).
        """
        # Follow the model's own device instead of a global `device`, which may be
        # undefined in this cell or differ from where the model was placed.
        target = self.position_embeddings.device
        B = x1.shape[0]
        tokens = self.patchsize * self.patchsize
        x1 = x1.reshape(B, -1, self.patchsize, self.patchsize).unsqueeze(1).to(target)
        x2 = x2.reshape(B, -1, tokens).to(target)
        if x2.shape[1] > 0:
            x2 = F.adaptive_avg_pool1d(x2.flatten(2).transpose(1, 2), 1).transpose(1, 2).reshape(B, -1, tokens)
        x1 = self.conv5(x1)
        x1 = x1.reshape(B, -1, self.patchsize, self.patchsize)
        x1 = self.hetconv_layer(x1)
        x1 = x1.flatten(2).transpose(-1, -2)  # (B, patchsize**2, FM*4) tokens
        x = x1 + self.position_embeddings
        x = self.dropout(x)
        x = self.ca(x, x2)
        x = x.reshape(x.shape[0], -1)
        out = self.fclayer(x)
        return out
    

In [None]:
## For GPU
import os
import numpy as np
import torch
import torch.nn as nn
import torch.utils.data as dataf
import time
from torchsummary import summary
from torch.nn import LayerNorm, Linear, Dropout, Softmax
from torch.nn.modules.container import Sequential
import copy
import torch.nn.functional as F
import torch.fft
import math
from functools import partial
import logger

# Default device for this cell: CUDA when available, otherwise CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

class HetConv(nn.Module):
    """Heterogeneous convolution: a grouped k x k conv plus a grouped 1 x 1 conv, summed.

    Args:
        in_channels: input channels (must be divisible by ``g`` and ``p``).
        out_channels: output channels (must be divisible by ``g`` and ``p``).
        kernel_size: spatial size of the grouped (groupwise) convolution.
        stride: stride used by both branches.
        padding: padding of the k x k branch. ``None`` (default) uses
            ``kernel_size // 2`` so that, for odd kernels, both branches produce
            the same spatial size and can be added.
        bias: whether both convolutions have a bias; ``None`` keeps PyTorch's default (True).
        p: number of groups of the pointwise (1 x 1) branch.
        g: number of groups of the groupwise (k x k) branch.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=None, bias=None, p=64, g=64):
        super(HetConv, self).__init__()
        if padding is None:
            # "Same" padding for odd kernels (kernel_size // 3 was only correct for k=3).
            padding = kernel_size // 2
        use_bias = True if bias is None else bool(bias)
        self.groupwise_conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, groups=g,
                                        padding=padding, stride=stride, bias=use_bias)
        self.pointwise_conv = nn.Conv2d(in_channels, out_channels, kernel_size=1, groups=p,
                                        stride=stride, bias=use_bias)

    def forward(self, x):
        return self.groupwise_conv(x) + self.pointwise_conv(x)

class CrossHL_attention(nn.Module):
    """Cross-HL attention: queries derived from LiDAR attend over keys derived from HSI tokens.

    Args:
        dim: feature size ``C`` of the HSI tokens (``x`` must have ``C == dim``).
        patches: number of HSI tokens ``N``; ``Wq`` and ``Wv`` both take
            ``patches`` input features.
        num_heads: number of heads ``h``. The query/value reshapes in ``forward``
            only line up when ``dim == num_heads ** 2`` (e.g. dim=64, num_heads=8).
        qkv_bias: add a bias term to the query/key/value projections.
        qk_scale: override for the default ``head_dim ** -0.5`` score scaling.
        attn_drop: dropout probability applied to the attention weights.
        proj_drop: dropout probability applied after the output projection.
    """

    def __init__(self, dim, patches, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0.1, proj_drop=0.1):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim ** -0.5
        self.dim = dim
        self.Wq = nn.Linear(patches, dim * num_heads, bias=qkv_bias)
        self.Wk = nn.Linear(dim, dim, bias=qkv_bias)
        self.Wv = nn.Linear(patches, dim, bias=qkv_bias)  # applied along the token axis
        self.linear_projection = nn.Linear(dim * num_heads, dim)
        self.linear_projection_drop = nn.Dropout(proj_drop)
        self.attn_dropout = nn.Dropout(attn_drop)

    def forward(self, x, x2):
        """
        Args:
            x: HSI tokens of shape (B, N, C) with N == patches and C == dim.
            x2: LiDAR descriptor with ``patches`` values on its last axis and one
                row per sample, e.g. (B, 1, patches).
        Returns:
            Tensor of shape (B, N, dim).
        """
        B, N, C = x.shape
        h = self.num_heads
        head_dim = self.dim // h
        # Query from the LiDAR vector: (B, h, h, head_dim).
        query = self.Wq(x2).reshape(B, h, h, head_dim)
        # Keys from the HSI tokens: (B, h, N, head_dim).
        key = self.Wk(x).reshape(B, N, h, head_dim).permute(0, 2, 1, 3)
        # Values mix along the token axis: (B, h, head_dim, C).
        value = self.Wv(x.transpose(1, 2)).reshape(B, C, h, head_dim).permute(0, 2, 3, 1)
        attention = torch.einsum('bhid,bhjd->bhij', key, query) * self.scale
        attention = self.attn_dropout(attention.softmax(dim=-1))

        x = torch.einsum('bhij,bhjd->bhid', attention, value)  # (B, h, N, C)
        # Bring the token axis forward before merging heads; reshaping (B, h, N, C)
        # directly to (B, N, h*C) would interleave rows of different tokens/heads.
        x = x.permute(0, 2, 1, 3).reshape(B, N, -1)
        x = self.linear_projection(x)
        x = self.linear_projection_drop(x)
        return x

class GlobalFilter(nn.Module):
    """Learnable global filter in the 2-D Fourier domain (GFNet-style token mixer).

    Tokens are laid out on an ``a x b`` grid, transformed with ``rfft2``,
    multiplied element-wise by the learned complex filter stored in
    ``complex_weight`` and transformed back with ``irfft2``.

    Args:
        dim: token feature size ``C``.
        h, w: size of the stored frequency grid. When it differs from the
            input's rfft2 grid ``(a, b // 2 + 1)``, the weight is bilinearly
            resized to fit, so the same parameter works for any patch size.
    """

    def __init__(self, dim, h=14, w=8):
        super().__init__()
        # Real and imaginary parts stored in the last axis: (h, w, dim, 2).
        self.complex_weight = nn.Parameter(torch.randn(h, w, dim, 2, dtype=torch.float32) * 0.02)
        self.w = w
        self.h = h

    def _filter(self, rows, cols):
        """Return the learned complex filter as a (rows, cols, dim) tensor."""
        weight = self.complex_weight
        if tuple(weight.shape[:2]) != (rows, cols):
            # (h, w, C, 2) -> (2, C, h, w): resize real and imaginary parts as two images.
            weight = F.interpolate(weight.permute(3, 2, 0, 1), size=(rows, cols),
                                   mode='bilinear', align_corners=True)
            weight = weight.permute(2, 3, 1, 0)
        return torch.complex(weight[..., 0], weight[..., 1])

    def forward(self, x, spatial_size=None):
        """Filter tokens x of shape (B, N, C); returns float32 (B, N, C).

        ``spatial_size`` gives the (a, b) token grid; by default a square grid
        with side ``int(sqrt(N))`` is assumed.
        """
        B, N, C = x.shape
        if spatial_size is None:
            a = b = int(math.sqrt(N))
        else:
            a, b = spatial_size
        x = x.reshape(B, a, b, C).to(torch.float32)
        x = torch.fft.rfft2(x, dim=(1, 2), norm='ortho')  # (B, a, b // 2 + 1, C), complex
        # The registered parameter moves with the module, so no global `device` is
        # needed; a fresh random tensor per call was never trained and was nondeterministic.
        x = x * self._filter(a, b // 2 + 1)
        x = torch.fft.irfft2(x, s=(a, b), dim=(1, 2), norm='ortho')
        x = x.reshape(B, N, C)
        return x

class MultiLayerPerceptron(nn.Module):
    """Feed-forward block: Linear -> GELU -> Dropout -> Linear -> Dropout.

    Args:
        dim: size of the input and output feature axis.
        mlp_dim: size of the hidden layer.
    """

    def __init__(self, dim, mlp_dim):
        super().__init__()
        self.fclayer1 = Linear(dim, mlp_dim)  # dim -> mlp_dim
        self.fclayer2 = Linear(mlp_dim, dim)  # mlp_dim -> dim
        self.act_fn = nn.GELU()
        self.dropout = Dropout(0.1)  # one module reused at both dropout sites
        self._init_weights()

    def _init_weights(self):
        """Xavier-uniform weights, then near-zero normal biases (same RNG order as before)."""
        layers = (self.fclayer1, self.fclayer2)
        for layer in layers:
            nn.init.xavier_uniform_(layer.weight)
        for layer in layers:
            nn.init.normal_(layer.bias, std=1e-6)

    def forward(self, x):
        hidden = self.dropout(self.act_fn(self.fclayer1(x)))
        return self.dropout(self.fclayer2(hidden))

class SingleEncoderBlock(nn.Module):
    """One Cross-HL encoder layer with a parallel global-filter branch.

    The HSI tokens ``x1`` are layer-normalized (``attention_norm``) and fed to
    two branches: a ``GlobalFilter`` and ``CrossHL_attention`` (which also
    receives ``x2``). The attention branch gets a residual connection to ``x1``.
    Both branches are normalized with ``ffn_norm``, fused by summation and
    passed through the MLP, with a second residual around the MLP.

    Args:
        dim: token feature size.
        num_heads: number of heads passed to CrossHL_attention.
        mlp_dim: hidden size of the MLP.
        patches: number of tokens per sample (default: an 11 x 11 patch).
    """

    def __init__(self, dim, num_heads, mlp_dim, patches=11**2):
        super().__init__()
        self.attention_norm = LayerNorm(dim, eps=1e-6)  # pre-attention norm
        self.ffn_norm = LayerNorm(dim, eps=1e-6)        # pre-MLP norm, applied to both branches
        self.ffn = MultiLayerPerceptron(dim, mlp_dim)
        self.cross_hl_attention = CrossHL_attention(dim=dim, patches=patches, num_heads=num_heads)
        self.Globalfilter = GlobalFilter(dim, h=14, w=8)

    def forward(self, x1, x2):
        # The norm defined in __init__ is `attention_norm`; `self.layer_norm` does not exist here.
        x = self.attention_norm(x1)
        x_gf = self.Globalfilter(x)
        x_chl = self.cross_hl_attention(x, x2) + x1  # residual around attention
        res = x_chl  # residual around the MLP is the post-attention output
        # Fuse the branches feature-wise. Concatenating along dim=0 stacked them on
        # the batch axis, doubling the batch and breaking the residual add.
        x = self.ffn_norm(x_chl) + self.ffn_norm(x_gf)
        x = self.ffn(x)
        return x + res

class Encoder(nn.Module):
    """Stack of ``depth`` SingleEncoderBlock layers followed by a final LayerNorm.

    ``forward(x, x2)`` threads ``x`` through every block (each block also gets
    the same ``x2``), normalizes the result and returns only token 0, i.e.
    ``encoded[:, 0]``.
    """

    def __init__(self, dim, num_heads=8, mlp_dim=512, depth=2):
        super().__init__()
        self.layer = nn.ModuleList()
        self.encoder_norm = LayerNorm(dim, eps=1e-6)
        for _ in range(depth):
            self.layer.append(SingleEncoderBlock(dim, num_heads, mlp_dim))

    def forward(self, x, x2):
        hidden = x
        for block in self.layer:
            hidden = block(hidden, x2)
        return self.encoder_norm(hidden)[:, 0]

class CrossHL_Transformer(nn.Module):
    """HSI + LiDAR patch classifier: spectral Conv3d stem, HetConv, Cross-HL encoder, linear head.

    Args:
        FM: width multiplier; the token feature size is ``FM * 4``.
        NC: number of HSI bands. The depth-9 Conv3d (no depth padding) leaves
            ``NC - 8`` spectral slices with 8 feature maps each.
        NCLidar: number of LiDAR channels (stored; not used by ``forward``).
        Classes: number of output classes.
        patchsize: side length of the square input patches.
    """

    def __init__(self, FM, NC, NCLidar, Classes, patchsize):
        super(CrossHL_Transformer, self).__init__()
        self.patchsize = patchsize
        self.NCLidar = NCLidar
        self.conv5 = nn.Sequential(
            nn.Conv3d(1, 8, (9, 3, 3), padding=(0, 1, 1), stride=1),
            nn.BatchNorm3d(8),
            nn.ReLU()
        )
        self.hetconv_layer = nn.Sequential(
            HetConv(8 * (NC - 8), FM*4, p=1, g=(FM*4)//4 if (8 * (NC - 8))%FM == 0 else (FM*4)//8),
            nn.BatchNorm2d(FM*4),
            nn.ReLU()
        )
        self.ca = Encoder(FM*4)
        self.fclayer = nn.Linear(FM*4, Classes)
        self.position_embeddings = nn.Parameter(torch.randn(1, (patchsize**2), FM*4))
        self.dropout = nn.Dropout(0.1)
        torch.nn.init.xavier_uniform_(self.fclayer.weight)
        torch.nn.init.normal_(self.fclayer.bias, std=1e-6)

    def forward(self, x1, x2):
        """Classify a batch of patches.

        Args:
            x1: HSI patches, reshaped to (B, bands, patchsize, patchsize).
            x2: LiDAR patches, reshaped to (B, channels, patchsize**2); the
                channels are averaged into one (B, 1, patchsize**2) vector.
        Returns:
            Logits of shape (B, Classes).
        """
        # Follow the model's own device rather than the global `device`; a model
        # kept on CPU while CUDA is available would otherwise get CUDA inputs.
        target = self.position_embeddings.device
        B = x1.shape[0]
        tokens = self.patchsize * self.patchsize
        x1 = x1.reshape(B, -1, self.patchsize, self.patchsize).unsqueeze(1).to(target)
        x2 = x2.reshape(B, -1, tokens).to(target)
        if x2.shape[1] > 0:
            x2 = F.adaptive_avg_pool1d(x2.flatten(2).transpose(1, 2), 1).transpose(1, 2).reshape(B, -1, tokens)
        x1 = self.conv5(x1)
        x1 = x1.reshape(B, -1, self.patchsize, self.patchsize)
        x1 = self.hetconv_layer(x1)
        x1 = x1.flatten(2).transpose(-1, -2)  # (B, patchsize**2, FM*4) tokens
        x = x1 + self.position_embeddings
        x = self.dropout(x)
        x = self.ca(x, x2)
        x = x.reshape(x.shape[0], -1)
        out = self.fclayer(x)
        return out

# Dataset and other parts of the code where data is loaded and moved to GPU should be updated accordingly.


In [None]:
## For GPU
import os
import numpy as np
import torch
import torch.nn as nn
import torch.utils.data as dataf
import time
from torchsummary import summary
from torch.nn import LayerNorm, Linear, Dropout, Softmax
from torch.nn.modules.container import Sequential
import copy
import torch.nn.functional as F
import torch.fft
import math
from functools import partial
import logger

# Default device for this cell: CUDA when available, otherwise CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

class HetConv(nn.Module):
    """Heterogeneous convolution: a grouped k x k conv plus a grouped 1 x 1 conv, summed.

    Args:
        in_channels: input channels (must be divisible by ``g`` and ``p``).
        out_channels: output channels (must be divisible by ``g`` and ``p``).
        kernel_size: spatial size of the grouped (groupwise) convolution.
        stride: stride used by both branches.
        padding: padding of the k x k branch. ``None`` (default) uses
            ``kernel_size // 2`` so that, for odd kernels, both branches produce
            the same spatial size and can be added.
        bias: whether both convolutions have a bias; ``None`` keeps PyTorch's default (True).
        p: number of groups of the pointwise (1 x 1) branch.
        g: number of groups of the groupwise (k x k) branch.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=None, bias=None, p=64, g=64):
        super(HetConv, self).__init__()
        if padding is None:
            # "Same" padding for odd kernels (kernel_size // 3 was only correct for k=3).
            padding = kernel_size // 2
        use_bias = True if bias is None else bool(bias)
        self.groupwise_conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, groups=g,
                                        padding=padding, stride=stride, bias=use_bias)
        self.pointwise_conv = nn.Conv2d(in_channels, out_channels, kernel_size=1, groups=p,
                                        stride=stride, bias=use_bias)

    def forward(self, x):
        return self.groupwise_conv(x) + self.pointwise_conv(x)

class CrossHL_attention(nn.Module):
    """Cross-HL attention: queries derived from LiDAR attend over keys derived from HSI tokens.

    Args:
        dim: feature size ``C`` of the HSI tokens (``x`` must have ``C == dim``).
        patches: number of HSI tokens ``N``; ``Wq`` and ``Wv`` both take
            ``patches`` input features.
        num_heads: number of heads ``h``. The query/value reshapes in ``forward``
            only line up when ``dim == num_heads ** 2`` (e.g. dim=64, num_heads=8).
        qkv_bias: add a bias term to the query/key/value projections.
        qk_scale: override for the default ``head_dim ** -0.5`` score scaling.
        attn_drop: dropout probability applied to the attention weights.
        proj_drop: dropout probability applied after the output projection.
    """

    def __init__(self, dim, patches, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0.1, proj_drop=0.1):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim ** -0.5
        self.dim = dim
        self.Wq = nn.Linear(patches, dim * num_heads, bias=qkv_bias)
        self.Wk = nn.Linear(dim, dim, bias=qkv_bias)
        self.Wv = nn.Linear(patches, dim, bias=qkv_bias)  # applied along the token axis
        self.linear_projection = nn.Linear(dim * num_heads, dim)
        self.linear_projection_drop = nn.Dropout(proj_drop)
        self.attn_dropout = nn.Dropout(attn_drop)

    def forward(self, x, x2):
        """
        Args:
            x: HSI tokens of shape (B, N, C) with N == patches and C == dim.
            x2: LiDAR descriptor with ``patches`` values on its last axis and one
                row per sample, e.g. (B, 1, patches).
        Returns:
            Tensor of shape (B, N, dim).
        """
        B, N, C = x.shape
        h = self.num_heads
        head_dim = self.dim // h
        # Query from the LiDAR vector: (B, h, h, head_dim).
        query = self.Wq(x2).reshape(B, h, h, head_dim)
        # Keys from the HSI tokens: (B, h, N, head_dim).
        key = self.Wk(x).reshape(B, N, h, head_dim).permute(0, 2, 1, 3)
        # Values mix along the token axis: (B, h, head_dim, C).
        value = self.Wv(x.transpose(1, 2)).reshape(B, C, h, head_dim).permute(0, 2, 3, 1)
        attention = torch.einsum('bhid,bhjd->bhij', key, query) * self.scale
        attention = self.attn_dropout(attention.softmax(dim=-1))

        x = torch.einsum('bhij,bhjd->bhid', attention, value)  # (B, h, N, C)
        # Bring the token axis forward before merging heads; reshaping (B, h, N, C)
        # directly to (B, N, h*C) would interleave rows of different tokens/heads.
        x = x.permute(0, 2, 1, 3).reshape(B, N, -1)
        x = self.linear_projection(x)
        x = self.linear_projection_drop(x)
        return x

class GlobalFilter(nn.Module):
    """Learnable global filter in the 2-D Fourier domain (GFNet-style token mixer).

    Tokens are laid out on an ``a x b`` grid, transformed with ``rfft2``,
    multiplied element-wise by the learned complex filter stored in
    ``complex_weight`` and transformed back with ``irfft2``.

    Args:
        dim: token feature size ``C``.
        h, w: size of the stored frequency grid. When it differs from the
            input's rfft2 grid ``(a, b // 2 + 1)``, the weight is bilinearly
            resized to fit, so the same parameter works for any patch size.
    """

    def __init__(self, dim, h=14, w=8):
        super().__init__()
        # Real and imaginary parts stored in the last axis: (h, w, dim, 2).
        self.complex_weight = nn.Parameter(torch.randn(h, w, dim, 2, dtype=torch.float32) * 0.02)
        self.w = w
        self.h = h

    def _filter(self, rows, cols):
        """Return the learned complex filter as a (rows, cols, dim) tensor."""
        weight = self.complex_weight
        if tuple(weight.shape[:2]) != (rows, cols):
            # (h, w, C, 2) -> (2, C, h, w): resize real and imaginary parts as two images.
            weight = F.interpolate(weight.permute(3, 2, 0, 1), size=(rows, cols),
                                   mode='bilinear', align_corners=True)
            weight = weight.permute(2, 3, 1, 0)
        return torch.complex(weight[..., 0], weight[..., 1])

    def forward(self, x, spatial_size=None):
        """Filter tokens x of shape (B, N, C); returns float32 (B, N, C).

        ``spatial_size`` gives the (a, b) token grid; by default a square grid
        with side ``int(sqrt(N))`` is assumed.
        """
        B, N, C = x.shape
        if spatial_size is None:
            a = b = int(math.sqrt(N))
        else:
            a, b = spatial_size
        x = x.reshape(B, a, b, C).to(torch.float32)
        x = torch.fft.rfft2(x, dim=(1, 2), norm='ortho')  # (B, a, b // 2 + 1, C), complex
        # The registered parameter moves with the module, so no global `device` is
        # needed; a fresh random tensor per call was never trained and was nondeterministic.
        x = x * self._filter(a, b // 2 + 1)
        x = torch.fft.irfft2(x, s=(a, b), dim=(1, 2), norm='ortho')
        x = x.reshape(B, N, C)
        return x

class MultiLayerPerceptron(nn.Module):
    """Feed-forward block: Linear -> GELU -> Dropout -> Linear -> Dropout.

    Args:
        dim: size of the input and output feature axis.
        mlp_dim: size of the hidden layer.
    """

    def __init__(self, dim, mlp_dim):
        super().__init__()
        self.fclayer1 = Linear(dim, mlp_dim)  # dim -> mlp_dim
        self.fclayer2 = Linear(mlp_dim, dim)  # mlp_dim -> dim
        self.act_fn = nn.GELU()
        self.dropout = Dropout(0.1)  # one module reused at both dropout sites
        self._init_weights()

    def _init_weights(self):
        """Xavier-uniform weights, then near-zero normal biases (same RNG order as before)."""
        layers = (self.fclayer1, self.fclayer2)
        for layer in layers:
            nn.init.xavier_uniform_(layer.weight)
        for layer in layers:
            nn.init.normal_(layer.bias, std=1e-6)

    def forward(self, x):
        hidden = self.dropout(self.act_fn(self.fclayer1(x)))
        return self.dropout(self.fclayer2(hidden))

class SingleEncoderBlock(nn.Module):
    """One Cross-HL encoder layer with a parallel global-filter branch.

    The HSI tokens ``x1`` are layer-normalized (``attention_norm``) and fed to
    two branches: a ``GlobalFilter`` and ``CrossHL_attention`` (which also
    receives ``x2``). The attention branch gets a residual connection to ``x1``.
    Both branches are normalized with ``ffn_norm``, fused by summation and
    passed through the MLP, with a second residual around the MLP.

    Args:
        dim: token feature size.
        num_heads: number of heads passed to CrossHL_attention.
        mlp_dim: hidden size of the MLP.
        patches: number of tokens per sample (default: an 11 x 11 patch).
    """

    def __init__(self, dim, num_heads, mlp_dim, patches=11**2):
        super().__init__()
        self.attention_norm = LayerNorm(dim, eps=1e-6)  # pre-attention norm
        self.ffn_norm = LayerNorm(dim, eps=1e-6)        # pre-MLP norm, applied to both branches
        self.ffn = MultiLayerPerceptron(dim, mlp_dim)
        self.cross_hl_attention = CrossHL_attention(dim=dim, patches=patches, num_heads=num_heads)
        self.Globalfilter = GlobalFilter(dim, h=14, w=8)

    def forward(self, x1, x2):
        x = self.attention_norm(x1)
        x_gf = self.Globalfilter(x)
        x_chl = self.cross_hl_attention(x, x2) + x1  # residual around attention
        res = x_chl  # residual around the MLP is the post-attention output
        # Fuse the branches feature-wise. Concatenating along dim=0 stacked them on
        # the batch axis, doubling the batch and breaking the residual add.
        x = self.ffn_norm(x_chl) + self.ffn_norm(x_gf)
        x = self.ffn(x)
        return x + res

class Encoder(nn.Module):
    """Stack of ``depth`` SingleEncoderBlock layers followed by a final LayerNorm.

    ``forward(x, x2)`` threads ``x`` through every block (each block also gets
    the same ``x2``), normalizes the result and returns only token 0, i.e.
    ``encoded[:, 0]``.
    """

    def __init__(self, dim, num_heads=8, mlp_dim=512, depth=2):
        super().__init__()
        self.layer = nn.ModuleList()
        self.encoder_norm = LayerNorm(dim, eps=1e-6)
        for _ in range(depth):
            self.layer.append(SingleEncoderBlock(dim, num_heads, mlp_dim))

    def forward(self, x, x2):
        hidden = x
        for block in self.layer:
            hidden = block(hidden, x2)
        return self.encoder_norm(hidden)[:, 0]

class CrossHL_Transformer(nn.Module):
    """HSI + LiDAR patch classifier: spectral Conv3d stem, HetConv, Cross-HL encoder, linear head.

    Args:
        FM: width multiplier; the token feature size is ``FM * 4``.
        NC: number of HSI bands. The depth-9 Conv3d (no depth padding) leaves
            ``NC - 8`` spectral slices with 8 feature maps each.
        NCLidar: number of LiDAR channels (stored; not used by ``forward``).
        Classes: number of output classes.
        patchsize: side length of the square input patches.
    """

    def __init__(self, FM, NC, NCLidar, Classes, patchsize):
        super(CrossHL_Transformer, self).__init__()
        self.patchsize = patchsize
        self.NCLidar = NCLidar
        self.conv5 = nn.Sequential(
            nn.Conv3d(1, 8, (9, 3, 3), padding=(0, 1, 1), stride=1),
            nn.BatchNorm3d(8),
            nn.ReLU()
        )
        self.hetconv_layer = nn.Sequential(
            HetConv(8 * (NC - 8), FM*4, p=1, g=(FM*4)//4 if (8 * (NC - 8))%FM == 0 else (FM*4)//8),
            nn.BatchNorm2d(FM*4),
            nn.ReLU()
        )
        self.ca = Encoder(FM*4)
        self.fclayer = nn.Linear(FM*4, Classes)
        self.position_embeddings = nn.Parameter(torch.randn(1, (patchsize**2), FM*4))
        self.dropout = nn.Dropout(0.1)
        torch.nn.init.xavier_uniform_(self.fclayer.weight)
        torch.nn.init.normal_(self.fclayer.bias, std=1e-6)

    def forward(self, x1, x2):
        """Classify a batch of patches.

        Args:
            x1: HSI patches, reshaped to (B, bands, patchsize, patchsize).
            x2: LiDAR patches, reshaped to (B, channels, patchsize**2); the
                channels are averaged into one (B, 1, patchsize**2) vector.
        Returns:
            Logits of shape (B, Classes).
        """
        # Follow the model's own device rather than the global `device`; a model
        # kept on CPU while CUDA is available would otherwise get CUDA inputs.
        target = self.position_embeddings.device
        B = x1.shape[0]
        tokens = self.patchsize * self.patchsize
        x1 = x1.reshape(B, -1, self.patchsize, self.patchsize).unsqueeze(1).to(target)
        x2 = x2.reshape(B, -1, tokens).to(target)
        if x2.shape[1] > 0:
            x2 = F.adaptive_avg_pool1d(x2.flatten(2).transpose(1, 2), 1).transpose(1, 2).reshape(B, -1, tokens)
        x1 = self.conv5(x1)
        x1 = x1.reshape(B, -1, self.patchsize, self.patchsize)
        x1 = self.hetconv_layer(x1)
        x1 = x1.flatten(2).transpose(-1, -2)  # (B, patchsize**2, FM*4) tokens
        x = x1 + self.position_embeddings
        x = self.dropout(x)
        x = self.ca(x, x2)
        x = x.reshape(x.shape[0], -1)
        out = self.fclayer(x)
        return out

# Dataset and other parts of the code where data is loaded and moved to GPU should be updated accordingly.
