# In [None]:
#sparse with low rank-attention
import torch
import torch.nn as nn
import torch.nn.functional as F
from timm.models.layers import DropPath
from torchvision import models, transforms, datasets
from torch.utils.data import DataLoader, random_split
from timm import create_model
import torch.optim as optim
from torch.utils.data import DataLoader, random_split
import random
from thop import profile
import numpy as np

# Seed every random number generator used by this script so runs are repeatable.
random_seed = 42  # any integer works

# Python, NumPy, torch (CPU) and torch CUDA (current device, then all devices)
# are seeded in that order with the same value.
for _seed_fn in (
    random.seed,
    np.random.seed,
    torch.manual_seed,
    torch.cuda.manual_seed,
    torch.cuda.manual_seed_all,
):
    _seed_fn(random_seed)
del _seed_fn

# Image pipelines. Both resize to 224x224 and map each channel with
# (x - 0.5) / 0.5; only the train/val pipeline adds random flips and a
# random rotation of up to 30 degrees.
train_val_transform = transforms.Compose(
    [
        transforms.Resize(size=(224, 224)),
        transforms.RandomHorizontalFlip(),
        transforms.RandomVerticalFlip(),
        transforms.RandomRotation(degrees=30),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ]
)

test_transform = transforms.Compose(
    [
        transforms.Resize(size=(224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ]
)

# Dataset loading
dataset_path = 'Kather_texture_2016_image_tiles_5000'
# Transform-free view: used for its length, its class list and the index split.
dataset = datasets.ImageFolder(dataset_path)
test_size = int(len(dataset) * 0.2)
train_size = len(dataset) - test_size
train_val_split, test_split = random_split(dataset, [train_size, test_size])

val_size = int(train_size * 0.1)
train_size = train_size - val_size
train_split, val_split = random_split(train_val_split, [train_size, val_size])

# `train_split` / `val_split` index into `train_val_split`, which in turn
# indexes into `dataset`; resolve them to indices of the underlying folder.
train_indices = [train_val_split.indices[i] for i in train_split.indices]
val_indices = [train_val_split.indices[i] for i in val_split.indices]
test_indices = list(test_split.indices)

# Each split needs its own transform. Assigning `.transform` through
# `subset.dataset` either lands on an intermediate Subset (where it is never
# read) or mutates the single shared ImageFolder (last assignment wins), so
# separate ImageFolder views are created and indexed with the split indices.
augmented_view = datasets.ImageFolder(dataset_path, transform=train_val_transform)
plain_view = datasets.ImageFolder(dataset_path, transform=test_transform)

train_dataset = torch.utils.data.Subset(augmented_view, train_indices)
# NOTE(review): validation deliberately keeps the augmenting transform, as the
# name `train_val_transform` suggests; switch to `plain_view` if deterministic
# validation metrics are preferred.
val_dataset = torch.utils.data.Subset(augmented_view, val_indices)
test_dataset = torch.utils.data.Subset(plain_view, test_indices)

# DataLoaders
batch_size = 32
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=2)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=2)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=2)

# Teacher models (ViT, DeiT, Swin)
# The teachers only supply soft targets and are never optimised in this script,
# so they are put in eval mode to disable any dropout / stochastic depth.
# NOTE(review): requesting `num_classes=len(dataset.classes)` presumably replaces
# the pretrained classification head with a freshly initialised one; confirm
# whether the teachers should be fine-tuned before distillation.
teacher_vit = create_model('vit_base_patch16_224', pretrained=True, num_classes=len(dataset.classes)).cuda().eval()
teacher_deit = create_model('deit_base_patch16_224', pretrained=True, num_classes=len(dataset.classes)).cuda().eval()
teacher_swin = create_model('swin_base_patch4_window7_224', pretrained=True, num_classes=len(dataset.classes)).cuda().eval()

# Low-Rank Sparse Multi-Head Attention
class LowRankSparseMultiheadAttention(nn.Module):
    """Multi-head self-attention with low-rank Q/K/V projections and top-k sparsity.

    Every head owns its own pair of bias-free linear maps
    ``embed_dim -> rank -> embed_dim`` for queries, keys and values, so each
    head works at the full ``embed_dim`` width. Head outputs are concatenated
    (``num_heads * embed_dim``) and projected back to ``embed_dim``.

    Args:
        embed_dim: Size of the last dimension of the input and output.
        num_heads: Number of independent heads.
        rank: Bottleneck width of each low-rank projection.
        sparsity_ratio: Fraction of keys each query keeps (the highest scores).
    """

    def __init__(self, embed_dim, num_heads, rank, sparsity_ratio=0.5):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.rank = rank
        self.sparsity_ratio = sparsity_ratio

        # Independent low-rank projections for each head, using full embed_dim
        self.q_lows = nn.ModuleList([nn.Linear(embed_dim, rank, bias=False) for _ in range(num_heads)])
        self.q_highs = nn.ModuleList([nn.Linear(rank, embed_dim, bias=False) for _ in range(num_heads)])

        self.k_lows = nn.ModuleList([nn.Linear(embed_dim, rank, bias=False) for _ in range(num_heads)])
        self.k_highs = nn.ModuleList([nn.Linear(rank, embed_dim, bias=False) for _ in range(num_heads)])

        self.v_lows = nn.ModuleList([nn.Linear(embed_dim, rank, bias=False) for _ in range(num_heads)])
        self.v_highs = nn.ModuleList([nn.Linear(rank, embed_dim, bias=False) for _ in range(num_heads)])

        self.out_proj = nn.Linear(embed_dim * num_heads, embed_dim)
        # Each head is embed_dim wide, so the usual 1/sqrt(head_dim) is 1/sqrt(embed_dim).
        self.scale = embed_dim ** -0.5

    def sparse_attention(self, attn_scores, sparsity_ratio):
        """Keep only the highest scores of every query row.

        Pruned positions are set to ``-inf`` so that a subsequent softmax
        gives them exactly zero probability. (Multiplying by a 0/1 mask would
        leave them at score 0, which still receives weight ``exp(0)``.)

        Args:
            attn_scores: Raw scores whose last dimension runs over keys.
            sparsity_ratio: Fraction of keys to keep per row; the resulting
                count is clamped to ``[1, number of keys]``.

        Returns:
            A tensor of the same shape as ``attn_scores``. Scores tied with
            the k-th largest value are all kept.
        """
        num_keys = attn_scores.size(-1)

        if num_keys == 1:
            return attn_scores  # Dense attention if sequence length is 1

        # Clamp so that tiny ratios still keep one key and ratios > 1 keep all.
        num_to_keep = min(num_keys, max(1, int(sparsity_ratio * num_keys)))
        if num_to_keep == num_keys:
            return attn_scores

        top_scores, _ = torch.topk(attn_scores, k=num_to_keep, dim=-1)
        threshold = top_scores.min(dim=-1, keepdim=True)[0]
        return attn_scores.masked_fill(attn_scores < threshold, float("-inf"))

    @staticmethod
    def _project_heads(lows, highs, x):
        """Apply every head's low-rank projection to ``x`` and stack on dim 1."""
        return torch.stack([high(low(x)) for low, high in zip(lows, highs)], dim=1)

    def forward(self, x):
        """Run sparse low-rank self-attention.

        Args:
            x: Tensor of shape ``(batch_size, seq_length, embed_dim)``.

        Returns:
            Tensor of shape ``(batch_size, seq_length, embed_dim)``.
        """
        batch_size, seq_length, _ = x.size()

        # Each of q, k, v: (batch_size, num_heads, seq_length, embed_dim)
        q = self._project_heads(self.q_lows, self.q_highs, x)
        k = self._project_heads(self.k_lows, self.k_highs, x)
        v = self._project_heads(self.v_lows, self.v_highs, x)

        # Scaled dot-product attention per head
        attn_scores = torch.matmul(q, k.transpose(-2, -1)) * self.scale

        # Apply sparse attention
        sparse_attn_scores = self.sparse_attention(attn_scores, self.sparsity_ratio)

        attn_probs = F.softmax(sparse_attn_scores, dim=-1)
        attn_output = torch.matmul(attn_probs, v)

        # Concatenate output from all heads along the embedding dimension
        attn_output = attn_output.transpose(1, 2).reshape(batch_size, seq_length, -1)

        # Final projection
        return self.out_proj(attn_output)
    
# Custom DeiT Layer using Low-Rank Sparse Attention
class CustomDeiTLayer(nn.Module):
    """Pre-norm transformer block built on low-rank sparse attention.

    The block applies ``LayerNorm -> attention`` and ``LayerNorm -> MLP``,
    each wrapped in a residual connection whose branch goes through
    ``drop_path``.

    Args:
        embed_dim: Width of the token embeddings.
        num_heads: Number of attention heads.
        rank: Bottleneck width handed to the attention module.
        mlp_ratio: Hidden width of the MLP as a multiple of ``embed_dim``.
        drop_path: Drop probability given to ``DropPath``; values <= 0
            disable it (an identity is used instead).
        sparsity_ratio: Sparsity ratio handed to the attention module.
    """

    def __init__(self, embed_dim, num_heads, rank, mlp_ratio=4., drop_path=0.1, sparsity_ratio=0.5):
        super().__init__()
        hidden_features = int(embed_dim * mlp_ratio)

        self.norm1 = nn.LayerNorm(embed_dim)
        self.attn = LowRankSparseMultiheadAttention(embed_dim, num_heads, rank, sparsity_ratio)
        if drop_path > 0.:
            self.drop_path = DropPath(drop_path)
        else:
            self.drop_path = nn.Identity()
        self.norm2 = nn.LayerNorm(embed_dim)

        self.mlp = nn.Sequential(
            nn.Linear(embed_dim, hidden_features),
            nn.GELU(),
            nn.Linear(hidden_features, embed_dim),
        )

    def forward(self, x):
        """Return ``x`` after the attention and MLP residual sub-blocks."""
        attn_branch = self.attn(self.norm1(x))
        x = x + self.drop_path(attn_branch)
        mlp_branch = self.mlp(self.norm2(x))
        return x + self.drop_path(mlp_branch)

# Hybrid Student Model
class HybridStudentModel(nn.Module):
    """Student network: two CNN backbones feeding a small transformer stack.

    The outputs of a pretrained ResNet-18 and DenseNet-121 are concatenated,
    linearly embedded to ``embed_dim``, passed as a length-1 sequence through
    ``num_layers`` ``CustomDeiTLayer`` blocks and classified.

    Args:
        num_classes: Number of output classes.
        embed_dim: Width of the transformer embeddings.
        num_heads: Attention heads per transformer block.
        num_layers: Number of transformer blocks.
        rank: Low-rank bottleneck passed to each block.
        drop_path_rate: Drop-path probability passed to each block.
            NOTE(review): the default of 0.95 looks unusually high for a
            drop probability -- confirm it is intentional.
        sparsity_ratio: Sparsity ratio passed to each block.
    """

    def __init__(self, num_classes, embed_dim=768, num_heads=8, num_layers=2, rank=32, drop_path_rate=0.95, sparsity_ratio=0.9):
        super().__init__()
        # NOTE(review): `pretrained=True` is reported as deprecated by newer
        # torchvision releases in favour of `weights=...`; kept for compatibility.
        self.resnet = models.resnet18(pretrained=True)
        self.densenet = models.densenet121(pretrained=True)

        # Custom DeiT layers with low-rank sparse attention
        self.deit_layers = nn.ModuleList([
            CustomDeiTLayer(embed_dim, num_heads, rank, drop_path=drop_path_rate, sparsity_ratio=sparsity_ratio)
            for _ in range(num_layers)
        ])

        # Width of the concatenated backbone outputs, read from the backbones'
        # final linear layers instead of hard-coding it.
        cnn_feature_dim = self.resnet.fc.out_features + self.densenet.classifier.out_features
        self.deit_embed = nn.Linear(cnn_feature_dim, embed_dim)
        self.norm = nn.LayerNorm(embed_dim)
        self.classifier = nn.Linear(embed_dim, num_classes)

    def forward(self, x):
        """Map a batch of images to class logits of shape ``(batch, num_classes)``."""
        # CNN feature extraction
        resnet_feats = self.resnet(x)
        densenet_feats = self.densenet(x)
        combined_feats = torch.cat((resnet_feats, densenet_feats), dim=1)
        combined_feats = combined_feats.flatten(1)

        # Add sequence dimension (seq_length=1). With a single token the
        # attention weights are trivially 1, so sparsity has no effect here.
        combined_feats = combined_feats.unsqueeze(1)

        # Embedding and Transformer block
        x = self.deit_embed(combined_feats)
        for layer in self.deit_layers:
            x = layer(x)

        x = self.norm(x)
        x = x.squeeze(1)  # Remove sequence dimension
        return self.classifier(x)

# Distillation Loss
class DistillationLoss(nn.Module):
    """Blend of a hard cross-entropy loss and a soft teacher-matching KL loss.

    ``loss = alpha * T^2 * KL(teacher_T || student_T) + (1 - alpha) * CE``
    where the ``_T`` distributions are softmaxes of the logits divided by the
    temperature ``T``.

    Args:
        alpha: Weight of the soft (distillation) term.
        temperature: Softening temperature applied to both sets of logits.
    """

    def __init__(self, alpha=0.5, temperature=3.0):
        super().__init__()
        self.alpha, self.temperature = alpha, temperature
        self.ce_loss = nn.CrossEntropyLoss()
        self.kl_div = nn.KLDivLoss(reduction="batchmean")

    def forward(self, student_logits, teacher_logits, ground_truth):
        """Return the combined scalar loss for one batch."""
        tau = self.temperature

        # KLDivLoss expects log-probabilities as input and probabilities as target.
        student_log_probs = F.log_softmax(student_logits / tau, dim=1)
        teacher_probs = F.softmax(teacher_logits / tau, dim=1)
        soft_loss = self.kl_div(student_log_probs, teacher_probs) * (tau ** 2)

        hard_loss = self.ce_loss(student_logits, ground_truth)
        return self.alpha * soft_loss + (1 - self.alpha) * hard_loss

# Evaluate Model
def evaluate_model(model, data_loader, criterion):
    """Compute top-1 accuracy and mean per-batch loss of ``model``.

    The model is switched to eval mode (and left there). Batches are moved to
    the device of the model's first parameter, so the function works on CPU
    as well as on GPU.

    Args:
        model: Module mapping a batch of inputs to class logits.
        data_loader: Iterable yielding ``(images, labels)`` batches.
        criterion: Callable ``criterion(outputs, labels)`` returning a scalar
            tensor.

    Returns:
        ``(accuracy, avg_loss)``: accuracy in percent over all samples and the
        loss averaged over batches. ``(0.0, 0.0)`` if no samples were seen.
    """
    model.eval()
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else None

    correct, total, loss_sum, num_batches = 0, 0, 0.0, 0
    with torch.no_grad():
        for images, labels in data_loader:
            if device is not None:
                images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            loss_sum += criterion(outputs, labels).item()
            num_batches += 1
            predicted = outputs.argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    if total == 0:
        # Empty loader: nothing to average, avoid dividing by zero.
        return 0.0, 0.0

    accuracy = 100 * correct / total
    avg_loss = loss_sum / num_batches
    return accuracy, avg_loss

# Train Model with Distillation
def train_model_with_distillation(student_model, teacher_models, train_loader, val_loader, distillation_criterion, optimizer, num_epochs=20):
    """Train ``student_model`` against the averaged logits of several teachers.

    Args:
        student_model: Module being optimised.
        teacher_models: Iterable of modules whose logits are averaged to form
            the soft targets. They are put in eval mode and never updated.
        train_loader: Iterable of ``(images, labels)`` training batches.
        val_loader: Iterable of ``(images, labels)`` validation batches.
        distillation_criterion: Callable
            ``(student_logits, teacher_logits, labels) -> loss`` exposing a
            ``ce_loss`` attribute that is used for evaluation.
        optimizer: Optimizer over the student's parameters.
        num_epochs: Number of passes over ``train_loader``.

    Returns:
        ``(student_model, train_losses, val_losses, train_accuracies,
        val_accuracies)`` with one list entry per epoch.
    """
    # Follow the student's device instead of assuming CUDA.
    device = next(student_model.parameters()).device

    # Teachers only provide targets: disable dropout / stochastic depth so
    # their logits are deterministic.
    for teacher in teacher_models:
        teacher.eval()

    train_losses, val_losses, train_accuracies, val_accuracies = [], [], [], []
    for epoch in range(num_epochs):
        # evaluate_model() leaves the student in eval mode, so re-enable
        # training mode at the start of every epoch.
        student_model.train()
        running_loss, num_batches = 0.0, 0
        for images, labels in train_loader:
            images, labels = images.to(device), labels.to(device)
            optimizer.zero_grad()
            student_outputs = student_model(images)
            with torch.no_grad():
                teacher_logits = [teacher(images) for teacher in teacher_models]
                combined_teacher_logits = sum(teacher_logits) / len(teacher_logits)
            loss = distillation_criterion(student_outputs, combined_teacher_logits, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
            num_batches += 1
        # Guard against an empty training loader.
        train_loss = running_loss / num_batches if num_batches else 0.0
        train_accuracy, _ = evaluate_model(student_model, train_loader, distillation_criterion.ce_loss)
        val_accuracy, val_loss = evaluate_model(student_model, val_loader, distillation_criterion.ce_loss)
        train_losses.append(train_loss)
        val_losses.append(val_loss)
        train_accuracies.append(train_accuracy)
        val_accuracies.append(val_accuracy)
        print(f'Epoch [{epoch+1}/{num_epochs}], Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}, '
              f'Train Acc: {train_accuracy:.2f}%, Val Acc: {val_accuracy:.2f}%')
    return student_model, train_losses, val_losses, train_accuracies, val_accuracies

# Build the student, its loss and optimizer, then distil from the three teachers.
num_classes = len(dataset.classes)
teacher_models = [teacher_vit, teacher_deit, teacher_swin]

hybrid_model = HybridStudentModel(num_classes).cuda()
criterion = DistillationLoss(alpha=0.5, temperature=3.0)
optimizer = optim.SGD(hybrid_model.parameters(), lr=0.01, momentum=0.9)

# Train the student; per-epoch histories are returned alongside the model.
(trained_model, train_losses, val_losses,
 train_accuracies, val_accuracies) = train_model_with_distillation(
    student_model=hybrid_model,
    teacher_models=teacher_models,
    train_loader=train_loader,
    val_loader=val_loader,
    distillation_criterion=criterion,
    optimizer=optimizer,
    num_epochs=3,
)

# Held-out evaluation uses the plain cross-entropy part of the criterion.
test_accuracy, test_loss = evaluate_model(trained_model, test_loader, criterion.ce_loss)
print(f'Test Accuracy: {test_accuracy:.2f}%, Test Loss: {test_loss:.4f}')

# Function to count the number of parameters in a model
def count_custom_parameters(model, exclude_list=None):
    """Count trainable parameter elements of ``model`` outside given submodules.

    Args:
        model: Module whose parameters are counted.
        exclude_list: Optional iterable of modules whose parameters are
            skipped; ``None`` means nothing is excluded.

    Returns:
        Total ``numel()`` over parameters that require gradients and do not
        belong to any module in ``exclude_list``.
    """
    skipped = set()
    for submodule in (exclude_list or []):
        skipped.update(submodule.parameters())

    total = 0
    for param in model.parameters():
        if param.requires_grad and param not in skipped:
            total += param.numel()
    return total

# Build a fresh student for parameter / FLOP accounting. This rebinds
# `hybrid_model`; the trained network remains available as `trained_model`.
hybrid_model = HybridStudentModel(num_classes).cuda()

# Count only the custom layers: the ResNet and DenseNet backbones are excluded.
excluded_models = [hybrid_model.resnet, hybrid_model.densenet]
custom_params = count_custom_parameters(hybrid_model, excluded_models)
print(f'Number of trainable parameters in custom layers: {custom_params}')

# A single random 3x224x224 image is enough for operation counting.
dummy_input = torch.randn(1, 3, 224, 224).cuda()

# NOTE(review): thop's documentation appears to name the first return value
# MACs rather than FLOPs -- confirm whether a factor of two is wanted.
profile_result = profile(hybrid_model, inputs=(dummy_input,))
flops, params = profile_result
print(f'Number of FLOPs in the hybrid model: {flops:.2e}')

# Clock rate (in GHz) that would finish one forward pass per second on a
# processor retiring 8 FLOPs per cycle: flops / 8 cycles, divided by 1e9.
# Despite its name, `flops_per_cycle` therefore holds FLOPs per second at 1 GHz.
flops_per_cycle = 8 * 10**9
required_ghz = flops / flops_per_cycle
print(f'Approximate GHz required: {required_ghz:.4f} GHz')


# In [None]:
#simple multihead attention
import torch
import torch.nn as nn
import torch.nn.functional as F
from timm.models.layers import DropPath
from torchvision import models, transforms, datasets
from torch.utils.data import DataLoader, random_split
from timm import create_model
import torch.optim as optim
from torch.utils.data import DataLoader, random_split
import random
from thop import profile
import numpy as np

# Seed every random number generator used by this script so runs are repeatable.
random_seed = 42  # any integer works

# Python, NumPy, torch (CPU) and torch CUDA (current device, then all devices)
# are seeded in that order with the same value.
for _seed_fn in (
    random.seed,
    np.random.seed,
    torch.manual_seed,
    torch.cuda.manual_seed,
    torch.cuda.manual_seed_all,
):
    _seed_fn(random_seed)
del _seed_fn

# Image pipelines. Both resize to 224x224 and map each channel with
# (x - 0.5) / 0.5; only the train/val pipeline adds random flips and a
# random rotation of up to 30 degrees.
train_val_transform = transforms.Compose(
    [
        transforms.Resize(size=(224, 224)),
        transforms.RandomHorizontalFlip(),
        transforms.RandomVerticalFlip(),
        transforms.RandomRotation(degrees=30),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ]
)

test_transform = transforms.Compose(
    [
        transforms.Resize(size=(224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ]
)

# Dataset loading
dataset_path = 'Kather_texture_2016_image_tiles_5000'
# Transform-free view: used for its length, its class list and the index split.
dataset = datasets.ImageFolder(dataset_path)
test_size = int(len(dataset) * 0.2)
train_size = len(dataset) - test_size
train_val_split, test_split = random_split(dataset, [train_size, test_size])

val_size = int(train_size * 0.1)
train_size = train_size - val_size
train_split, val_split = random_split(train_val_split, [train_size, val_size])

# `train_split` / `val_split` index into `train_val_split`, which in turn
# indexes into `dataset`; resolve them to indices of the underlying folder.
train_indices = [train_val_split.indices[i] for i in train_split.indices]
val_indices = [train_val_split.indices[i] for i in val_split.indices]
test_indices = list(test_split.indices)

# Each split needs its own transform. Assigning `.transform` through
# `subset.dataset` either lands on an intermediate Subset (where it is never
# read) or mutates the single shared ImageFolder (last assignment wins), so
# separate ImageFolder views are created and indexed with the split indices.
augmented_view = datasets.ImageFolder(dataset_path, transform=train_val_transform)
plain_view = datasets.ImageFolder(dataset_path, transform=test_transform)

train_dataset = torch.utils.data.Subset(augmented_view, train_indices)
# NOTE(review): validation deliberately keeps the augmenting transform, as the
# name `train_val_transform` suggests; switch to `plain_view` if deterministic
# validation metrics are preferred.
val_dataset = torch.utils.data.Subset(augmented_view, val_indices)
test_dataset = torch.utils.data.Subset(plain_view, test_indices)

# DataLoaders
batch_size = 32
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=2)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=2)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=2)

# Teacher models (ViT, DeiT, Swin)
# The teachers only supply soft targets and are never optimised in this script,
# so they are put in eval mode to disable any dropout / stochastic depth.
# NOTE(review): requesting `num_classes=len(dataset.classes)` presumably replaces
# the pretrained classification head with a freshly initialised one; confirm
# whether the teachers should be fine-tuned before distillation.
teacher_vit = create_model('vit_base_patch16_224', pretrained=True, num_classes=len(dataset.classes)).cuda().eval()
teacher_deit = create_model('deit_base_patch16_224', pretrained=True, num_classes=len(dataset.classes)).cuda().eval()
teacher_swin = create_model('swin_base_patch4_window7_224', pretrained=True, num_classes=len(dataset.classes)).cuda().eval()

# Standard (dense) Multi-Head Attention
import torch
import torch.nn as nn
import torch.nn.functional as F

class MultiheadAttention(nn.Module):
    """Dense multi-head self-attention where every head is ``embed_dim`` wide.

    Each head has its own full ``embed_dim -> embed_dim`` linear projections
    for queries, keys and values (the embedding is not split across heads).
    Head outputs are concatenated to ``num_heads * embed_dim`` and projected
    back to ``embed_dim``.

    Args:
        embed_dim: Size of the last dimension of the input and output.
        num_heads: Number of independent heads.
    """

    def __init__(self, embed_dim, num_heads):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        # 1/sqrt(d) with d = embed_dim, the per-head width used here.
        self.scale = embed_dim ** -0.5

        def make_heads():
            return nn.ModuleList([nn.Linear(embed_dim, embed_dim) for _ in range(num_heads)])

        # One projection per head for queries, keys and values.
        self.q_proj_heads = make_heads()
        self.k_proj_heads = make_heads()
        self.v_proj_heads = make_heads()

        # Maps the concatenated head outputs back to embed_dim.
        self.out_proj = nn.Linear(embed_dim * num_heads, embed_dim)

    def forward(self, x):
        """Apply self-attention to ``x`` of shape ``(batch, seq, embed_dim)``.

        Returns a tensor of the same shape.
        """
        batch_size, seq_length, embed_dim = x.size()

        def per_head(projections):
            # -> (batch_size, num_heads, seq_length, embed_dim)
            return torch.stack([proj(x) for proj in projections], dim=1)

        queries = per_head(self.q_proj_heads)
        keys = per_head(self.k_proj_heads)
        values = per_head(self.v_proj_heads)

        # Scaled dot-product attention, computed for all heads at once.
        weights = F.softmax(torch.matmul(queries, keys.transpose(-2, -1)) * self.scale, dim=-1)
        context = torch.matmul(weights, values)

        # (batch, heads, seq, dim) -> (batch, seq, heads * dim)
        merged = context.transpose(1, 2).reshape(batch_size, seq_length, self.num_heads * embed_dim)
        return self.out_proj(merged)
    
# Custom DeiT Layer using standard multi-head attention
class CustomDeiTLayer(nn.Module):
    """Pre-norm transformer block built on dense multi-head attention.

    The block applies ``LayerNorm -> attention`` and ``LayerNorm -> MLP``,
    each wrapped in a residual connection whose branch goes through
    ``drop_path``.

    Args:
        embed_dim: Width of the token embeddings.
        num_heads: Number of attention heads.
        rank: Accepted for signature compatibility; not used by this variant.
        mlp_ratio: Hidden width of the MLP as a multiple of ``embed_dim``.
        drop_path: Drop probability given to ``DropPath``; values <= 0
            disable it (an identity is used instead).
        sparsity_ratio: Accepted for signature compatibility; not used by
            this variant.
    """

    def __init__(self, embed_dim, num_heads, rank, mlp_ratio=4., drop_path=0.1, sparsity_ratio=0.5):
        super().__init__()
        hidden_features = int(embed_dim * mlp_ratio)

        self.norm1 = nn.LayerNorm(embed_dim)
        self.attn = MultiheadAttention(embed_dim, num_heads)
        if drop_path > 0.:
            self.drop_path = DropPath(drop_path)
        else:
            self.drop_path = nn.Identity()
        self.norm2 = nn.LayerNorm(embed_dim)

        self.mlp = nn.Sequential(
            nn.Linear(embed_dim, hidden_features),
            nn.GELU(),
            nn.Linear(hidden_features, embed_dim),
        )

    def forward(self, x):
        """Return ``x`` after the attention and MLP residual sub-blocks."""
        attn_branch = self.attn(self.norm1(x))
        x = x + self.drop_path(attn_branch)
        mlp_branch = self.mlp(self.norm2(x))
        return x + self.drop_path(mlp_branch)

# Hybrid Student Model
class HybridStudentModel(nn.Module):
    """Student network: two CNN backbones feeding a small transformer stack.

    The outputs of a pretrained ResNet-18 and DenseNet-121 are concatenated,
    linearly embedded to ``embed_dim``, passed as a length-1 sequence through
    ``num_layers`` ``CustomDeiTLayer`` blocks and classified.

    Args:
        num_classes: Number of output classes.
        embed_dim: Width of the transformer embeddings.
        num_heads: Attention heads per transformer block.
        num_layers: Number of transformer blocks.
        rank: Forwarded to each block.
        drop_path_rate: Drop-path probability passed to each block.
        sparsity_ratio: Forwarded to each block.
    """

    def __init__(self, num_classes, embed_dim=768, num_heads=8, num_layers=2, rank=32, drop_path_rate=0.1, sparsity_ratio=0.5):
        super().__init__()
        # NOTE(review): `pretrained=True` is reported as deprecated by newer
        # torchvision releases in favour of `weights=...`; kept for compatibility.
        self.resnet = models.resnet18(pretrained=True)
        self.densenet = models.densenet121(pretrained=True)

        # Transformer blocks operating on the embedded CNN features
        self.deit_layers = nn.ModuleList([
            CustomDeiTLayer(embed_dim, num_heads, rank, drop_path=drop_path_rate, sparsity_ratio=sparsity_ratio)
            for _ in range(num_layers)
        ])

        # Width of the concatenated backbone outputs, read from the backbones'
        # final linear layers instead of hard-coding it.
        cnn_feature_dim = self.resnet.fc.out_features + self.densenet.classifier.out_features
        self.deit_embed = nn.Linear(cnn_feature_dim, embed_dim)
        self.norm = nn.LayerNorm(embed_dim)
        self.classifier = nn.Linear(embed_dim, num_classes)

    def forward(self, x):
        """Map a batch of images to class logits of shape ``(batch, num_classes)``."""
        # CNN feature extraction
        resnet_feats = self.resnet(x)
        densenet_feats = self.densenet(x)
        combined_feats = torch.cat((resnet_feats, densenet_feats), dim=1)
        combined_feats = combined_feats.flatten(1)

        # Add sequence dimension (seq_length=1); with a single token the
        # attention weights are trivially 1.
        combined_feats = combined_feats.unsqueeze(1)

        # Embedding and Transformer block
        x = self.deit_embed(combined_feats)
        for layer in self.deit_layers:
            x = layer(x)

        x = self.norm(x)
        x = x.squeeze(1)  # Remove sequence dimension
        return self.classifier(x)

# Distillation Loss
class DistillationLoss(nn.Module):
    """Blend of a hard cross-entropy loss and a soft teacher-matching KL loss.

    ``loss = alpha * T^2 * KL(teacher_T || student_T) + (1 - alpha) * CE``
    where the ``_T`` distributions are softmaxes of the logits divided by the
    temperature ``T``.

    Args:
        alpha: Weight of the soft (distillation) term.
        temperature: Softening temperature applied to both sets of logits.
    """

    def __init__(self, alpha=0.5, temperature=3.0):
        super().__init__()
        self.alpha, self.temperature = alpha, temperature
        self.ce_loss = nn.CrossEntropyLoss()
        self.kl_div = nn.KLDivLoss(reduction="batchmean")

    def forward(self, student_logits, teacher_logits, ground_truth):
        """Return the combined scalar loss for one batch."""
        tau = self.temperature

        # KLDivLoss expects log-probabilities as input and probabilities as target.
        student_log_probs = F.log_softmax(student_logits / tau, dim=1)
        teacher_probs = F.softmax(teacher_logits / tau, dim=1)
        soft_loss = self.kl_div(student_log_probs, teacher_probs) * (tau ** 2)

        hard_loss = self.ce_loss(student_logits, ground_truth)
        return self.alpha * soft_loss + (1 - self.alpha) * hard_loss

# Evaluate Model
def evaluate_model(model, data_loader, criterion):
    """Compute top-1 accuracy and mean per-batch loss of ``model``.

    The model is switched to eval mode (and left there). Batches are moved to
    the device of the model's first parameter, so the function works on CPU
    as well as on GPU.

    Args:
        model: Module mapping a batch of inputs to class logits.
        data_loader: Iterable yielding ``(images, labels)`` batches.
        criterion: Callable ``criterion(outputs, labels)`` returning a scalar
            tensor.

    Returns:
        ``(accuracy, avg_loss)``: accuracy in percent over all samples and the
        loss averaged over batches. ``(0.0, 0.0)`` if no samples were seen.
    """
    model.eval()
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else None

    correct, total, loss_sum, num_batches = 0, 0, 0.0, 0
    with torch.no_grad():
        for images, labels in data_loader:
            if device is not None:
                images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            loss_sum += criterion(outputs, labels).item()
            num_batches += 1
            predicted = outputs.argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    if total == 0:
        # Empty loader: nothing to average, avoid dividing by zero.
        return 0.0, 0.0

    accuracy = 100 * correct / total
    avg_loss = loss_sum / num_batches
    return accuracy, avg_loss

# Train Model with Distillation
def train_model_with_distillation(student_model, teacher_models, train_loader, val_loader, distillation_criterion, optimizer, num_epochs=20):
    """Train ``student_model`` against the averaged logits of several teachers.

    Args:
        student_model: Module being optimised.
        teacher_models: Iterable of modules whose logits are averaged to form
            the soft targets. They are put in eval mode and never updated.
        train_loader: Iterable of ``(images, labels)`` training batches.
        val_loader: Iterable of ``(images, labels)`` validation batches.
        distillation_criterion: Callable
            ``(student_logits, teacher_logits, labels) -> loss`` exposing a
            ``ce_loss`` attribute that is used for evaluation.
        optimizer: Optimizer over the student's parameters.
        num_epochs: Number of passes over ``train_loader``.

    Returns:
        ``(student_model, train_losses, val_losses, train_accuracies,
        val_accuracies)`` with one list entry per epoch.
    """
    # Follow the student's device instead of assuming CUDA.
    device = next(student_model.parameters()).device

    # Teachers only provide targets: disable dropout / stochastic depth so
    # their logits are deterministic.
    for teacher in teacher_models:
        teacher.eval()

    train_losses, val_losses, train_accuracies, val_accuracies = [], [], [], []
    for epoch in range(num_epochs):
        # evaluate_model() leaves the student in eval mode, so re-enable
        # training mode at the start of every epoch.
        student_model.train()
        running_loss, num_batches = 0.0, 0
        for images, labels in train_loader:
            images, labels = images.to(device), labels.to(device)
            optimizer.zero_grad()
            student_outputs = student_model(images)
            with torch.no_grad():
                teacher_logits = [teacher(images) for teacher in teacher_models]
                combined_teacher_logits = sum(teacher_logits) / len(teacher_logits)
            loss = distillation_criterion(student_outputs, combined_teacher_logits, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
            num_batches += 1
        # Guard against an empty training loader.
        train_loss = running_loss / num_batches if num_batches else 0.0
        train_accuracy, _ = evaluate_model(student_model, train_loader, distillation_criterion.ce_loss)
        val_accuracy, val_loss = evaluate_model(student_model, val_loader, distillation_criterion.ce_loss)
        train_losses.append(train_loss)
        val_losses.append(val_loss)
        train_accuracies.append(train_accuracy)
        val_accuracies.append(val_accuracy)
        print(f'Epoch [{epoch+1}/{num_epochs}], Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}, '
              f'Train Acc: {train_accuracy:.2f}%, Val Acc: {val_accuracy:.2f}%')
    return student_model, train_losses, val_losses, train_accuracies, val_accuracies

# Build the student, its loss and optimizer, then distil from the three teachers.
num_classes = len(dataset.classes)
teacher_models = [teacher_vit, teacher_deit, teacher_swin]

hybrid_model = HybridStudentModel(num_classes).cuda()
criterion = DistillationLoss(alpha=0.5, temperature=3.0)
optimizer = optim.SGD(hybrid_model.parameters(), lr=0.01, momentum=0.9)

# Train the student; per-epoch histories are returned alongside the model.
(trained_model, train_losses, val_losses,
 train_accuracies, val_accuracies) = train_model_with_distillation(
    student_model=hybrid_model,
    teacher_models=teacher_models,
    train_loader=train_loader,
    val_loader=val_loader,
    distillation_criterion=criterion,
    optimizer=optimizer,
    num_epochs=2,
)

# Held-out evaluation uses the plain cross-entropy part of the criterion.
test_accuracy, test_loss = evaluate_model(trained_model, test_loader, criterion.ce_loss)
print(f'Test Accuracy: {test_accuracy:.2f}%, Test Loss: {test_loss:.4f}')

# Function to count the number of parameters in a model
def count_parameters(model):
    """Return the total element count of ``model``'s trainable parameters.

    Parameters with ``requires_grad`` set to ``False`` are not counted.
    """
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total

# Report the trainable parameter count of the whole student (backbones included).
num_params = count_parameters(hybrid_model)
print(f'Number of trainable parameters in the hybrid model: {num_params}')

# A single random 3x224x224 image is enough for operation counting.
dummy_input = torch.randn(1, 3, 224, 224).cuda()

# NOTE(review): thop's documentation appears to name the first return value
# MACs rather than FLOPs -- confirm whether a factor of two is wanted.
profile_result = profile(hybrid_model, inputs=(dummy_input,))
flops, params = profile_result
print(f'Number of FLOPs in the hybrid model: {flops:.2e}')

# Clock rate (in GHz) that would finish one forward pass per second on a
# processor retiring 8 FLOPs per cycle: flops / 8 cycles, divided by 1e9.
# Despite its name, `flops_per_cycle` therefore holds FLOPs per second at 1 GHz.
flops_per_cycle = 8 * 10**9
required_ghz = flops / flops_per_cycle
print(f'Approximate GHz required: {required_ghz:.4f} GHz')

###########
