In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import models, transforms
from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder
from torch.cuda.amp import autocast, GradScaler

# Encoder Class
class Encoder(nn.Module):
    def __init__(self):
        super(Encoder, self).__init__()
        resnet = models.resnet18(weights='ResNet18_Weights.DEFAULT')

        # Extract stages from ResNet18
        self.stage1 = nn.Sequential(*list(resnet.children())[:4])  # Initial layers
        self.stage2 = nn.Sequential(*list(resnet.children())[4:6])  # First residual block
        self.stage3 = nn.Sequential(*list(resnet.children())[6:7])  # Second residual block
        self.stage4 = nn.Sequential(*list(resnet.children())[7:8])  # Third residual block

    def forward(self, x):
        F1 = self.stage1(x)  # 64 channels
        F2 = self.stage2(F1)  # 128 channels
        F3 = self.stage3(F2)  # 256 channels
        F4 = self.stage4(F3)  # 512 channels
        return F1, F2, F3, F4


# MSTB Class
class MSTB(nn.Module):
    def __init__(self, d_model, nhead, reduction_factor=4):
        super(MSTB, self).__init__()
        self.local_attn = nn.MultiheadAttention(embed_dim=d_model, num_heads=nhead, batch_first=True)
        self.global_attn = nn.MultiheadAttention(embed_dim=d_model, num_heads=nhead, batch_first=True)

        # Dimensionality reduction for regional features
        self.reduce_kv = nn.Conv2d(d_model, d_model // reduction_factor, kernel_size=1)

        # Feedforward network
        self.ffn = nn.Sequential(
            nn.Linear(d_model, d_model * 4),
            nn.ReLU(),
            nn.Linear(d_model * 4, d_model)
        )

    def forward(self, F_local, F_region):
        B, C, H, W = F_local.shape
    
        # Flatten local and regional features for attention
        F_local_flat = F_local.contiguous().view(B, C, -1).permute(0, 2, 1)  # (B, HW, C)
        F_region_flat = F_region.contiguous().view(B, C, -1).permute(0, 2, 1)  # (B, HW, C)
    
        # Dimensionality reduction for K and V
        F_region_reduced = self.reduce_kv(F_region).contiguous().view(B, -1, H * W).permute(0, 2, 1)  # (B, HW, reduced_C)
    
        # Compute local and global attention
        local_out, _ = self.local_attn(F_local_flat, F_local_flat, F_local_flat)
        global_out, _ = self.global_attn(F_local_flat, F_region_flat, F_region_flat)
    
        # Combine outputs
        combined = local_out + global_out  # (B, HW, d_model)
    
        # Pass through FFN
        combined_flat = combined.reshape(-1, combined.size(-1))  # (B * HW, d_model)
        refined_flat = self.ffn(combined_flat)  # (B * HW, d_model)
        refined = refined_flat.reshape(B, H * W, -1).permute(0, 2, 1).reshape(B, C, H, W)  # Reshape back to (B, C, H, W)
    
        return refined




# Decoder Class
class Decoder(nn.Module):
    def __init__(self):
        super(Decoder, self).__init__()
        self.conv_block = nn.Sequential(
            nn.Conv2d(512, 512, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(512, 512, kernel_size=3, padding=1),
            nn.ReLU()
        )

        self.mstb = MSTB(d_model=512, nhead=8, reduction_factor=4)

        # Upsampling layers
        self.upsample1 = nn.ConvTranspose2d(512, 256, kernel_size=4, stride=2, padding=1)  # 8x8 → 16x16
        self.upsample2 = nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1)  # 16x16 → 32x32
        self.upsample3 = nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1)  # 32x32 → 64x64

        # Final reconstruction layers
        self.final_upsample = nn.Upsample(size=(256, 256), mode='bilinear', align_corners=False)
        self.final_conv = nn.Conv2d(64, 3, kernel_size=3, padding=1)  # Output RGB image

    def forward(self, F4_E):
        # Initial processing with convolution block
        FConv = self.conv_block(F4_E)

        # Apply MSTB
        FTrans = self.mstb(F4_E, F4_E)

        # Upsample and reconstruct
        F3_D = self.upsample1(FTrans)
        F2_D = self.upsample2(F3_D)
        F1_D = self.upsample3(F2_D)
        reconstructed_image = self.final_upsample(F1_D)
        reconstructed_image = self.final_conv(reconstructed_image)

        return F1_D, F2_D, F3_D, FTrans, reconstructed_image


# Feature Comparison Loss
class FeatureComparisonLoss(nn.Module):
    def __init__(self, alpha=0.7, beta=0.3):
        super(FeatureComparisonLoss, self).__init__()
        self.alpha = alpha
        self.beta = beta
        self.mse = nn.MSELoss()

    def forward(self, F_E, F_D, original_image, reconstructed_image):
        feature_loss = sum(self.mse(f_e, f_d) for f_e, f_d in zip(F_E, F_D)) / len(F_E)
        reconstruction_loss = self.mse(original_image, reconstructed_image)
        return self.alpha * feature_loss + self.beta * reconstruction_loss


# Initialize Models, Optimizer, and Criterion
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
encoder = Encoder().to(device)
decoder = Decoder().to(device)
criterion = FeatureComparisonLoss()
optimizer = optim.Adam(list(encoder.parameters()) + list(decoder.parameters()), lr=0.0001)

# Training Loop
train_dataset = ImageFolder(root="/kaggle/input/brain-data", transform=transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
]))
train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)

num_epochs = 200
scaler = GradScaler()

for epoch in range(num_epochs):
    encoder.train()
    decoder.train()
    running_loss = 0.0

    for inputs, _ in train_loader:
        inputs = inputs.to(device)
        optimizer.zero_grad()

        with autocast():
            F1_E, F2_E, F3_E, F4_E = encoder(inputs)
            F1_D, F2_D, F3_D, F4_D, reconstructed_image = decoder(F4_E)
            loss = criterion([F1_E, F2_E, F3_E, F4_E], [F1_D, F2_D, F3_D, F4_D], inputs, reconstructed_image)

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        running_loss += loss.item()

    print(f"Epoch {epoch+1}, Loss: {running_loss / len(train_loader):.4f}")

Downloading: "https://download.pytorch.org/models/resnet18-f37072fd.pth" to /root/.cache/torch/hub/checkpoints/resnet18-f37072fd.pth
100%|██████████| 44.7M/44.7M [00:00<00:00, 82.0MB/s]
  scaler = GradScaler()
  with autocast():


Epoch 1, Loss: 0.1758
Epoch 2, Loss: 0.0829
Epoch 3, Loss: 0.0646
Epoch 4, Loss: 0.0541
Epoch 5, Loss: 0.0491
Epoch 6, Loss: 0.0461
Epoch 7, Loss: 0.0429
Epoch 8, Loss: 0.0416
Epoch 9, Loss: 0.0401
Epoch 10, Loss: 0.0403
Epoch 11, Loss: 0.0423
Epoch 12, Loss: 0.0376
Epoch 13, Loss: 0.0362
Epoch 14, Loss: 0.0365
Epoch 15, Loss: 0.0356
Epoch 16, Loss: 0.0349
Epoch 17, Loss: 0.0352
Epoch 18, Loss: 0.0342
Epoch 19, Loss: 0.0346
Epoch 20, Loss: 0.0329
Epoch 21, Loss: 0.0344
Epoch 22, Loss: 0.0326
Epoch 23, Loss: 0.0327
Epoch 24, Loss: 0.0334
Epoch 25, Loss: 0.0323
Epoch 26, Loss: 0.0349
Epoch 27, Loss: 0.0331
Epoch 28, Loss: 0.0319
Epoch 29, Loss: 0.0319
Epoch 30, Loss: 0.0311
Epoch 31, Loss: 0.0305
Epoch 32, Loss: 0.0307
Epoch 33, Loss: 0.0314
Epoch 34, Loss: 0.0305
Epoch 35, Loss: 0.0295
Epoch 36, Loss: 0.0291
Epoch 37, Loss: 0.0287
Epoch 38, Loss: 0.0288
Epoch 39, Loss: 0.0285
Epoch 40, Loss: 0.0280
Epoch 41, Loss: 0.0276
Epoch 42, Loss: 0.0278
Epoch 43, Loss: 0.0277
Epoch 44, Loss: 0.02

In [2]:

print("F1_E:", F1_E.shape, "F1_D:", F1_D.shape)
print("F2_E:", F2_E.shape, "F2_D:", F2_D.shape)
print("F3_E:", F3_E.shape, "F3_D:", F3_D.shape)
print("F4_E:", F4_E.shape, "F4_D:", F4_D.shape)

F1_E: torch.Size([11, 64, 64, 64]) F1_D: torch.Size([11, 64, 64, 64])
F2_E: torch.Size([11, 128, 32, 32]) F2_D: torch.Size([11, 128, 32, 32])
F3_E: torch.Size([11, 256, 16, 16]) F3_D: torch.Size([11, 256, 16, 16])
F4_E: torch.Size([11, 512, 8, 8]) F4_D: torch.Size([11, 512, 8, 8])


In [3]:
# Save the models
torch.save(encoder.state_dict(), "encoder_model.pth")
torch.save(decoder.state_dict(), "decoder_model.pth")
print("Models saved successfully.")

Models saved successfully.
