In [None]:
import torch
import torch.nn as nn
import transformers
import torch.optim as optim
from torch.utils.data import DataLoader,Dataset,random_split
import cv2
import os

In [None]:
# Dataset locations (the "/content/drive" prefix suggests a mounted Google Drive
# in Colab — adjust both paths when running elsewhere).
# image_dir holds the input images, mask_dir the corresponding label masks;
# later cells pair them purely by sorted filename order.
image_dir=r"/content/drive/MyDrive/CVC/converted"
mask_dir=r"/content/drive/MyDrive/CVC/converted_labels"

In [None]:
def _list_data_files(directory):
    """Return the sorted names of the regular, non-hidden files in `directory`.

    Hidden entries and sub-directories (for example `.ipynb_checkpoints` or
    `.DS_Store`) are skipped: a single stray entry in one folder would shift
    the sorted order and silently pair every image with the wrong mask.
    """
    return sorted(
        name
        for name in os.listdir(directory)
        if not name.startswith(".") and os.path.isfile(os.path.join(directory, name))
    )


# Images and masks are matched by position, so both listings must be sorted the same way.
sorted_image_files = _list_data_files(image_dir)
sorted_mask_files = _list_data_files(mask_dir)

In [None]:
# Build full paths; entry i of image_paths is paired with entry i of mask_paths.
image_paths = [os.path.join(image_dir, img) for img in sorted_image_files]
mask_paths = [os.path.join(mask_dir, msk) for msk in sorted_mask_files]

# Pairing is purely positional, so differing counts mean at least one image
# would be trained against the wrong (or a missing) mask.
if len(image_paths) != len(mask_paths):
    raise ValueError(
        f"Found {len(image_paths)} images but {len(mask_paths)} masks; "
        "every image needs exactly one mask."
    )

In [None]:
# Smoke test: confirm OpenCV can decode one image/mask pair before building the dataset.
if not image_paths or not mask_paths:
    raise ValueError("No image/mask files found; check image_dir and mask_dir.")

# Sample 9 is the preferred probe; fall back to the last available pair on
# datasets with fewer than ten files instead of failing with an IndexError.
sample_idx = min(9, len(image_paths) - 1, len(mask_paths) - 1)

image = cv2.imread(image_paths[sample_idx])
if image is None:
    raise ValueError(f"Image at {image_paths[sample_idx]} could not be loaded.")

mask = cv2.imread(mask_paths[sample_idx], cv2.IMREAD_GRAYSCALE)
if mask is None:
    raise ValueError(f"Mask at {mask_paths[sample_idx]} could not be loaded.")

In [None]:
class data(Dataset):
  """Paired image/mask dataset read from disk with OpenCV.

  Args:
    image_paths: sequence of image file paths.
    mask_paths: sequence of mask file paths; entry i belongs to image i.
    transform: optional callable invoked as ``transform(image=..., mask=...)``
      that returns a mapping with ``"image"`` and ``"mask"`` keys.

  Raises:
    ValueError: if the two path sequences differ in length, or (when an item
      is fetched) if a file cannot be decoded.
  """
  def __init__(self,image_paths,mask_paths,transform=None):
    # Pairing is positional, so unequal lengths can only produce wrong pairs.
    if len(image_paths) != len(mask_paths):
      raise ValueError(
          f"Got {len(image_paths)} image paths but {len(mask_paths)} mask paths."
      )
    self.img_paths=image_paths
    self.label_paths=mask_paths
    self.transform=transform

  def __len__(self):
    """Number of image/mask pairs."""
    return len(self.img_paths)

  def __getitem__(self,idx):
    """Load pair `idx` as (RGB image, grayscale mask), transformed if a transform is set."""
    image_path=self.img_paths[idx]
    mask_path=self.label_paths[idx]

    # cv2.imread signals failure by returning None rather than raising; check
    # explicitly so the error names the offending file.
    image=cv2.imread(image_path)
    if image is None:
      raise ValueError(f"Image at {image_path} could not be loaded.")
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    mask=cv2.imread(mask_path,cv2.IMREAD_GRAYSCALE)
    if mask is None:
      raise ValueError(f"Mask at {mask_path} could not be loaded.")

    if self.transform is not None:
      transformed_data=self.transform(image=image,mask=mask)
      image, mask = transformed_data["image"], transformed_data["mask"]
    return image,mask


In [None]:
import albumentations as A
from albumentations.pytorch import ToTensorV2

# Training pipeline: resize, random geometric/photometric augmentation, then
# normalisation to [-1, 1] and conversion to a CHW tensor.
train_transform = A.Compose([
    A.Resize(224, 224),
    A.HorizontalFlip(p=0.5),
    A.RandomBrightnessContrast(p=0.2),
    A.Rotate(limit=20, p=0.5),
    A.ShiftScaleRotate(shift_limit=0.05, scale_limit=0.1, rotate_limit=10, p=0.5),
    A.ElasticTransform(alpha=1, sigma=50, alpha_affine=50, p=0.5),
    A.GaussNoise(var_limit=(10.0, 50.0), p=0.3),
    A.MotionBlur(blur_limit=5, p=0.2),
    A.Normalize(mean=(0.5,0.5,0.5), std=(0.5, 0.5, 0.5)),
    ToTensorV2()
])

# Validation pipeline: the same deterministic preprocessing without any random
# augmentation, so validation loss/metrics are comparable between epochs.
val_transform = A.Compose([
    A.Resize(224, 224),
    A.Normalize(mean=(0.5,0.5,0.5), std=(0.5, 0.5, 0.5)),
    ToTensorV2()
])

# Kept so that any cell still referring to `transform` gets the training pipeline.
transform = train_transform

dataset=data(image_paths,mask_paths,train_transform)
train_size=int(len(dataset)*(0.75))
valid_size=len(dataset)-train_size

# A seeded generator makes the 75/25 split identical across kernel restarts,
# so a checkpoint is never evaluated on samples it was trained on.
split_generator=torch.Generator().manual_seed(42)
train_dataset,val_split=random_split(dataset,[train_size,valid_size],generator=split_generator)

# Same validation indices, but served through the augmentation-free pipeline.
val_dataset=torch.utils.data.Subset(data(image_paths,mask_paths,val_transform),val_split.indices)

train_loader=DataLoader(train_dataset,batch_size=16,shuffle=True)
val_loader=DataLoader(val_dataset,batch_size=16,shuffle=False)

In [None]:
import torchvision.transforms.functional as TF

In [None]:
class DoubleConv(nn.Module):
    """Two consecutive 3x3 conv -> batch norm -> ReLU stages.

    Both convolutions use stride 1 and padding 1, so the spatial size of the
    input is preserved; only the channel count changes to `out_channels`.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        stages = []
        # First stage maps in_channels -> out_channels, second keeps out_channels.
        for stage_in in (in_channels, out_channels):
            stages.extend([
                nn.Conv2d(stage_in, out_channels, kernel_size=3, stride=1, padding=1),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True),
            ])
        self.conv = nn.Sequential(*stages)

    def forward(self, x):
        """Apply both stages to `x` and return the resulting feature map."""
        features = self.conv(x)
        return features

In [None]:
from transformers import SwinModel
# Hugging Face checkpoint identifier of the Swin backbone loaded below.
model = "microsoft/swin-tiny-patch4-window7-224"
swin = SwinModel.from_pretrained(model)

# Print the module tree under a heading so the stage/block layout can be inspected.
print("Architecture:", swin, sep="\n")

In [None]:
# Take the blocks of encoder stage index 2 and freeze every parameter in them,
# so they are used as a fixed feature processor.
swin_bottleneck_stage = swin.encoder.layers[2].blocks.requires_grad_(False)

In [None]:
class BottleNeck(nn.Module):
    """Bottleneck that runs a CNN feature map through a stack of Swin blocks.

    Pipeline: 1x1 conv to the Swin channel width -> flatten to a token
    sequence -> Swin blocks -> reshape back to a feature map -> 1x1 conv to
    the channel width expected by the decoder.

    Args:
        encoder_out_channels: channels of the incoming encoder feature map.
        swin_blocks_module_list: iterable module of blocks, each called as
            ``block(hidden_states, (H, W))`` and returning a tuple whose first
            element is the updated ``(B, H*W, C)`` hidden states.
        resolution: ``(H, W)`` of the feature map the blocks operate on.
        swin_channels: channel width the blocks expect (384 for the stage
            used in this notebook).
        out_channels: channels of the returned feature map.
    """

    def __init__(self, encoder_out_channels, swin_blocks_module_list, resolution=(14, 14),
                 swin_channels=384, out_channels=1024):
        super().__init__()

        self.swin_blocks = swin_blocks_module_list

        # 1x1 conv adapting the encoder channels to what the Swin blocks expect.
        self.swin_channel_adapter = nn.Conv2d(
            in_channels=encoder_out_channels,
            out_channels=swin_channels,
            kernel_size=1
        )
        # 1x1 conv expanding the Swin output to the decoder's input width.
        self.re_adapter = nn.Conv2d(in_channels=swin_channels,
                                    out_channels=out_channels,
                                    kernel_size=1)
        self.resolution = tuple(resolution)

    def forward(self, x):
        """
        Run the bottleneck.

        Args:
            x (torch.Tensor): encoder feature map of shape
                (B, encoder_out_channels, H, W) with (H, W) == self.resolution.

        Returns:
            torch.Tensor: feature map of shape (B, out_channels, H, W).

        Raises:
            ValueError: if the spatial size of `x` differs from self.resolution.
        """
        height, width = self.resolution
        # A feature map with the same number of pixels but another aspect
        # ratio would reshape without error and scramble the token layout.
        if tuple(x.shape[-2:]) != (height, width):
            raise ValueError(
                f"BottleNeck expects a {height}x{width} feature map, "
                f"got {x.shape[-2]}x{x.shape[-1]}."
            )

        # 1. Adapt channels: (B, C_enc, H, W) -> (B, C_swin, H, W).
        x = self.swin_channel_adapter(x)

        # 2. Flatten to a token sequence: (B, C_swin, H, W) -> (B, H*W, C_swin).
        x = x.flatten(2).transpose(1, 2)

        # 3. Each block receives the hidden states plus the (H, W) grid size
        #    and returns a tuple; element 0 is the updated hidden states.
        for block in self.swin_blocks:
            x = block(x, self.resolution)[0]

        # 4. Back to a feature map: (B, H*W, C) -> (B, C, H, W). reshape (not
        #    view) is used because the transposed tensor is non-contiguous.
        batch_size, _, channels = x.shape
        x = x.transpose(1, 2).reshape(batch_size, channels, height, width)

        # 5. Expand to the decoder width: (B, C, H, W) -> (B, out_channels, H, W).
        return self.re_adapter(x)

In [None]:
# Smoke test: a 512-channel 14x14 map must come out as a 1024-channel 14x14 map.
test_input=torch.randn(1,512,14,14)
# Show only the shape; printing the tensor itself floods the cell output with random numbers.
print(test_input.shape)
model=BottleNeck(512,swin_bottleneck_stage,resolution=(14,14))
output=model(test_input)
print(output.shape)
assert output.shape == (1, 1024, 14, 14), f"unexpected bottleneck output shape {tuple(output.shape)}"

In [None]:
class Generator(nn.Module):
    """U-Net style segmentation network with residual encoder stages and a Swin bottleneck.

    Encoder: one `DoubleConv` per entry of `features`, each with a residual
    shortcut (1x1 conv when the channel count changes) followed by 2x2 max
    pooling. Bottleneck: `BottleNeck` built on the module-level
    `swin_bottleneck_stage`, fixed to a 14x14 grid. Decoder: transposed-conv
    upsampling, concatenation with the matching skip connection, `DoubleConv`.

    With the default four stages the input must be 224x224 so that four
    poolings yield the 14x14 map the bottleneck requires.

    Args:
        in_channels: channels of the input image.
        out_channels: channels of the returned logit map.
        features: channel width of each encoder stage, shallow to deep.

    Raises:
        ValueError: if the bottleneck output width is not ``2 * features[-1]``,
            which the first decoder stage requires.
    """

    def __init__(self, in_channels, out_channels, features=(64, 128, 256, 512)):
        super(Generator, self).__init__()
        # Own copy so the caller's sequence (or the default) is never shared.
        self.features = list(features)
        self.ups = nn.ModuleList()
        self.downs = nn.ModuleList()
        self.channel_adjust = nn.ModuleList()
        self.MaxPool = nn.MaxPool2d(kernel_size=2, stride=2)

        for feature in self.features:
            self.downs.append(DoubleConv(in_channels, feature))
            # Shortcut branch of the residual connection; a 1x1 conv is needed
            # only when the stage changes the channel count.
            if in_channels != feature:
                self.channel_adjust.append(nn.Conv2d(in_channels, feature, kernel_size=1))
            else:
                self.channel_adjust.append(nn.Identity())
            in_channels = feature

        # Swin bottleneck fed by the deepest encoder stage.
        self.BottleNeck = BottleNeck(self.features[-1], swin_bottleneck_stage, resolution=(14, 14))
        bottleneck_out = self.BottleNeck.re_adapter.out_channels
        if bottleneck_out != self.features[-1] * 2:
            raise ValueError(
                f"Bottleneck outputs {bottleneck_out} channels but the decoder expects "
                f"{self.features[-1] * 2} (2 * features[-1])."
            )

        # `ups` alternates [upsample, DoubleConv] per decoder stage.
        for feature in reversed(self.features):
            self.ups.append(nn.ConvTranspose2d(feature * 2, feature, kernel_size=2, stride=2))
            self.ups.append(DoubleConv(feature * 2, feature))

        self.final_conv = nn.Conv2d(self.features[0], out_channels, kernel_size=1)

    def forward(self, x):
        """Return segmentation logits of shape (B, out_channels, H, W) for images `x`."""
        skip_connections = []
        for idx, down in enumerate(self.downs):
            identity = x
            x = down(x)
            identity = self.channel_adjust[idx](identity)
            # Defensive: align the shortcut if a stage ever changes spatial size.
            if x.shape[2:] != identity.shape[2:]:
                identity = TF.resize(identity, size=x.shape[2:])
            x = x + identity
            skip_connections.append(x)
            x = self.MaxPool(x)

        x = self.BottleNeck(x)  # Swin Transformer bottleneck

        # Deepest skip first, matching the decoder order.
        skip_connections = skip_connections[::-1]

        for idx in range(0, len(self.ups), 2):
            x = self.ups[idx](x)
            skip_connection = skip_connections[idx // 2]

            # Only the spatial size has to agree before concatenating along channels.
            if x.shape[2:] != skip_connection.shape[2:]:
                x = TF.resize(x, size=skip_connection.shape[2:])
            concat_skip = torch.cat((skip_connection, x), dim=1)
            x = self.ups[idx + 1](concat_skip)

        return self.final_conv(x)

In [None]:
# Smoke test: a 224x224 RGB image must yield a single-channel 224x224 logit map.
test_input=torch.randn(1,3,224,224)
model=Generator(in_channels=3,out_channels=1,features=[64,128,256,512])
output=model(test_input)
print(output.shape)
assert output.shape == (1, 1, 224, 224), f"unexpected generator output shape {tuple(output.shape)}"

In [None]:
class Discriminator(nn.Module):
    """Fully convolutional discriminator that scores image/mask pairs (PatchGAN-style).

    The input is the image concatenated with a mask along the channel axis.
    Each stage halves the spatial size with a stride-2 4x4 convolution; a
    final 4x4 convolution maps to one channel, giving a grid of raw logits
    (no sigmoid), suitable for `nn.BCEWithLogitsLoss`.

    Args:
        in_channels: channels of the concatenated input (image + mask).
        features: output channels of each downsampling stage.
    """

    def __init__(self, in_channels=4, features=(64, 128, 256, 512)):
        super(Discriminator, self).__init__()
        features = list(features)

        # First stage takes the concatenated input; no batch norm here.
        layers = [
            nn.Sequential(
                nn.Conv2d(in_channels, features[0], kernel_size=4, stride=2, padding=1, bias=False),
                nn.LeakyReLU(0.2, inplace=True)
            )
        ]

        # Remaining stages: conv -> batch norm -> LeakyReLU.
        for in_feat, out_feat in zip(features[:-1], features[1:]):
            layers.append(
                nn.Sequential(
                    nn.Conv2d(in_feat, out_feat, kernel_size=4, stride=2, padding=1, bias=False),
                    nn.BatchNorm2d(out_feat),
                    nn.LeakyReLU(0.2, inplace=True)
                )
            )

        self.model = nn.Sequential(*layers)

        # Final conv producing one logit per receptive-field patch.
        self.final_conv = nn.Conv2d(features[-1], 1, kernel_size=4, stride=1, padding=0, bias=False)

    def forward(self, x):
        """Return the (B, 1, H', W') logit map for the concatenated image/mask tensor `x`."""
        x = self.model(x)
        return self.final_conv(x)


In [None]:
# Smoke test: four stride-2 stages take 224 -> 14, the final 4x4 conv gives an 11x11 logit grid.
test_input=torch.randn(1,4,224,224)
model=Discriminator(in_channels=4,features=[64,128,256,512])
output=model(test_input)
print(output.shape)
assert output.shape == (1, 1, 11, 11), f"unexpected discriminator output shape {tuple(output.shape)}"

In [None]:
def weights_init_normal(m):
    """Initialise one module in place; intended for use with ``model.apply``.

    - Conv2d / ConvTranspose2d / Linear: weights ~ N(0, 0.02), bias 0.
    - BatchNorm2d: weights ~ N(1, 0.02), bias 0.
    - LayerNorm: weights 1, bias 0.
    - Any other module is left untouched.

    Norm layers created without affine parameters have ``weight``/``bias``
    set to None and are skipped instead of raising.

    Note: this overwrites whatever weights the module currently holds, so it
    must not be applied to layers carrying pretrained weights.
    """
    if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d, nn.Linear)):
        nn.init.normal_(m.weight, 0.0, 0.02)
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)

    elif isinstance(m, nn.BatchNorm2d):
        if m.weight is not None:
            nn.init.normal_(m.weight, 1.0, 0.02)
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)

    elif isinstance(m, nn.LayerNorm):
        if m.weight is not None:
            nn.init.constant_(m.weight, 1.0)
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)


In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.optim.lr_scheduler import ReduceLROnPlateau
from tqdm import tqdm
import numpy as np
import matplotlib.pyplot as plt

# Loss and Metrics (as provided)
class FocalLoss(nn.Module):
    """Binary focal loss computed from raw logits.

    loss = mean(alpha * (1 - p_t) ** gamma * BCE), where p_t is the predicted
    probability of the true class. Pixels whose target equals 1 are treated as
    positives; every other target value is treated as negative when
    selecting p_t.

    Args:
        alpha: constant scale applied to every element.
        gamma: focusing exponent; larger values down-weight easy pixels more.
    """

    def __init__(self, alpha=0.25, gamma=2):
        super(FocalLoss, self).__init__()
        self.alpha = alpha
        self.gamma = gamma

    def forward(self, inputs, targets):
        """Return the mean focal loss for logits `inputs` against `targets` (same shape)."""
        # The cross-entropy is taken directly from the logits: applying sigmoid
        # first and then binary_cross_entropy saturates for large |logit|,
        # which clamps the loss and zeroes its gradient.
        bce = nn.functional.binary_cross_entropy_with_logits(inputs, targets, reduction='none')
        probs = torch.sigmoid(inputs)
        pt = torch.where(targets == 1, probs, 1 - probs)
        loss = self.alpha * (1 - pt) ** self.gamma * bce
        return loss.mean()

class ComboLoss(nn.Module):
    """Weighted sum of soft Dice, BCE-with-logits and focal losses.

    All three terms are computed from raw logits. The Dice term is evaluated
    once over the whole batch (all elements summed together), not per sample.

    Args:
        weight_dice: multiplier of the Dice term.
        weight_bce: multiplier of the BCE term.
        weight_focal: multiplier of the focal term.
    """

    def __init__(self, weight_dice=0.7, weight_bce=0.2, weight_focal=0.7):
        super().__init__()
        self.weight_dice, self.weight_bce, self.weight_focal = weight_dice, weight_bce, weight_focal
        self.bce = nn.BCEWithLogitsLoss()
        self.focal = FocalLoss(alpha=0.25, gamma=2)

    def forward(self, preds, masks):
        """Return the combined scalar loss for logits `preds` against `masks`."""
        eps = 1e-6  # keeps the Dice ratio defined when both inputs are empty
        probs = torch.sigmoid(preds)
        overlap = (probs * masks).sum()
        dice_term = 1 - (2.0 * overlap + eps) / (probs.sum() + masks.sum() + eps)
        bce_term = self.bce(preds, masks)
        focal_term = self.focal(preds, masks)

        total = self.weight_dice * dice_term + self.weight_bce * bce_term
        return total + self.weight_focal * focal_term
def dice_coefficient(y_true, y_pred, smooth=1e-6):
    """Dice score 2*|A∩B| / (|A| + |B|) of two arrays, smoothed by `smooth`.

    Both inputs are flattened first; `smooth` is added to numerator and
    denominator, so two all-zero inputs score 1.
    """
    truth, pred = y_true.flatten(), y_pred.flatten()
    overlap = (truth * pred).sum()
    numerator = 2. * overlap + smooth
    denominator = truth.sum() + pred.sum() + smooth
    return numerator / denominator

def iou(y_true, y_pred, smooth=1e-6):
    """Intersection-over-union of two arrays, smoothed by `smooth`.

    Both inputs are flattened first; `smooth` is added to numerator and
    denominator, so two all-zero inputs score 1.
    """
    truth, pred = y_true.flatten(), y_pred.flatten()
    overlap = (truth * pred).sum()
    # |A ∪ B| = |A| + |B| - |A ∩ B|
    combined = truth.sum() + pred.sum() - overlap
    return (overlap + smooth) / (combined + smooth)

def accuracy(y_true, y_pred):
    """Fraction of positions at which the two (flattened) NumPy arrays are equal."""
    truth, pred = y_true.flatten(), y_pred.flatten()
    n_correct = (truth == pred).sum().item()
    return n_correct / truth.size

# Training Loop
def train(generator, discriminator, dataloader, g_criterion, d_criterion, g_optimizer, d_optimizer, device,
          update_discriminator=True, lambda_adv=0.1, d_update_interval=5, threshold=0.3):
    """Run one adversarial training epoch.

    The generator is updated on every batch with
    ``g_criterion(logits, masks) + lambda_adv * adversarial_loss``. The
    discriminator is updated on every `d_update_interval`-th batch, and only
    when `update_discriminator` is true.

    Args:
        generator: segmentation network returning logits; must expose a
            `BottleNeck` sub-module, which is kept in eval mode.
        discriminator: network scoring ``cat([image, mask], dim=1)`` with logits.
        dataloader: yields ``(images, masks)`` with masks in the 0..255 range.
        g_criterion: segmentation loss taking ``(logits, masks)``.
        d_criterion: adversarial loss taking ``(logits, labels)``.
        g_optimizer, d_optimizer: optimizers of the two networks.
        device: device the batches are moved to.
        update_discriminator: whether the discriminator is trained this epoch.
        lambda_adv: weight of the adversarial term in the generator loss.
        d_update_interval: train the discriminator every this many batches.
        threshold: probability cut-off used to binarise predictions for metrics.

    Returns:
        tuple: (mean generator loss per batch, mean discriminator loss per
        discriminator update — 0.0 if it was never updated, mean Dice,
        mean IoU, mean pixel accuracy).
    """
    generator.train()
    discriminator.train()
    # Keep the bottleneck (frozen Swin blocks) in inference mode.
    generator.BottleNeck.eval()
    running_g_loss = 0.0
    running_d_loss = 0.0
    d_updates = 0
    dice_scores, iou_scores, accuracy_scores = [], [], []

    for step, (images, masks) in enumerate(tqdm(dataloader, desc="Training")):
        images, masks = images.to(device), masks.to(device) / 255.0
        masks = masks.unsqueeze(1)

        # ---- Generator step ----
        g_optimizer.zero_grad()
        fake_masks = generator(images)
        # The discriminator sees real masks in [0, 1]; feed it probabilities,
        # not raw logits, so it cannot tell real from fake by value range alone.
        fake_probs = torch.sigmoid(fake_masks)
        disc_pred_fake = discriminator(torch.cat([images, fake_probs], dim=1))
        adv_loss = d_criterion(disc_pred_fake, torch.ones_like(disc_pred_fake))
        g_loss = g_criterion(fake_masks, masks) + lambda_adv * adv_loss
        g_loss.backward()
        g_optimizer.step()
        running_g_loss += g_loss.item()

        # ---- Training metrics (convert to NumPy once per batch) ----
        masks_np = masks.cpu().numpy()
        preds_np = (fake_probs.detach() > threshold).cpu().numpy()
        dice_scores.append(dice_coefficient(masks_np, preds_np))
        iou_scores.append(iou(masks_np, preds_np))
        accuracy_scores.append(accuracy(masks_np, preds_np))

        # ---- Discriminator step, every `d_update_interval` batches ----
        if update_discriminator and step % d_update_interval == 0:
            # Also clears gradients left on the discriminator by the generator step.
            d_optimizer.zero_grad()
            disc_pred_real = discriminator(torch.cat([images, masks], dim=1))
            # detach(): the discriminator loss must not back-propagate into the generator.
            disc_pred_fake = discriminator(torch.cat([images, fake_probs.detach()], dim=1))
            d_loss = d_criterion(disc_pred_real, torch.ones_like(disc_pred_real)) + \
                     d_criterion(disc_pred_fake, torch.zeros_like(disc_pred_fake))
            d_loss.backward()
            d_optimizer.step()
            running_d_loss += d_loss.item()
            d_updates += 1

    train_dice = np.mean(dice_scores)
    train_iou = np.mean(iou_scores)
    train_acc = np.mean(accuracy_scores)
    # Average the discriminator loss over the batches it was actually trained on.
    mean_d_loss = running_d_loss / max(d_updates, 1)
    return running_g_loss / len(dataloader), mean_d_loss, train_dice, train_iou, train_acc
def validate(generator, dataloader, criterion, device):
    """Evaluate the generator on `dataloader` without gradient tracking.

    Masks are scaled from 0..255 to 0..1 and given a channel axis; predictions
    are binarised at a sigmoid probability of 0.3 for the metrics.

    Returns:
        tuple: (mean loss per batch, mean Dice, mean IoU, mean pixel accuracy).
    """
    generator.eval()
    total_loss = 0.0
    dice_scores = []
    iou_scores = []
    accuracy_scores = []

    with torch.no_grad():
        for images, masks in tqdm(dataloader, desc="Validating"):
            images = images.to(device)
            masks = (masks.to(device) / 255.0).unsqueeze(1)

            logits = generator(images)
            total_loss += criterion(logits, masks).item()

            # One host copy per tensor, shared by all three metrics.
            truth_np = masks.cpu().numpy()
            pred_np = (torch.sigmoid(logits) > 0.3).cpu().numpy()
            dice_scores.append(dice_coefficient(truth_np, pred_np))
            iou_scores.append(iou(truth_np, pred_np))
            accuracy_scores.append(accuracy(truth_np, pred_np))

    mean_loss = total_loss / len(dataloader)
    return mean_loss, np.mean(dice_scores), np.mean(iou_scores), np.mean(accuracy_scores)
# Main Function
def _init_new_layers(generator):
    """Apply `weights_init_normal` to every generator layer except the Swin blocks.

    The Swin blocks hold pretrained, frozen weights; re-initialising them would
    leave a random transformer in the bottleneck that can never be trained.
    """
    pretrained_ids = {id(module) for module in generator.BottleNeck.swin_blocks.modules()}
    for module in generator.modules():
        if id(module) not in pretrained_ids:
            weights_init_normal(module)


def _show_sample_prediction(generator, val_loader, device, epoch):
    """Plot image, ground-truth mask and prediction for the first validation sample."""
    sample_images, sample_masks = next(iter(val_loader))
    sample_images, sample_masks = sample_images.to(device), sample_masks.to(device) / 255.0
    sample_masks = sample_masks.unsqueeze(1)

    generator.eval()
    with torch.no_grad():
        sample_preds = torch.sigmoid(generator(sample_images)) > 0.3

    # CHW -> HWC, then undo Normalize(mean=0.5, std=0.5) so imshow gets RGB in [0, 1].
    sample_image_np = sample_images[0].cpu().numpy().transpose(1, 2, 0)
    sample_image_np = np.clip(sample_image_np * 0.5 + 0.5, 0.0, 1.0)
    sample_mask_np = sample_masks[0].squeeze().cpu().numpy()
    sample_pred_np = sample_preds[0].squeeze().cpu().numpy()

    fig = plt.figure(figsize=(15, 5))
    plt.subplot(1, 3, 1)
    plt.imshow(sample_image_np)
    plt.title('Original Image')
    plt.axis('off')

    plt.subplot(1, 3, 2)
    plt.imshow(sample_mask_np, cmap='gray')
    plt.title('Ground Truth Mask')
    plt.axis('off')

    plt.subplot(1, 3, 3)
    plt.imshow(sample_pred_np, cmap='gray')
    plt.title(f'Predicted Mask (Epoch {epoch+1})')
    plt.axis('off')

    plt.show()
    # One figure is created per epoch; close it so they do not pile up in memory.
    plt.close(fig)


def _plot_history(g_losses, d_losses, val_losses, val_dice_scores, val_iou_scores, val_acc_scores):
    """Save and show the loss curves and the validation-metric curves."""
    epochs_list = range(1, len(g_losses) + 1)

    # Losses
    plt.figure(figsize=(10, 6))
    plt.plot(epochs_list, g_losses, label='Generator Loss', color='blue')
    plt.plot(epochs_list, d_losses, label='Discriminator Loss', color='orange')
    plt.plot(epochs_list, val_losses, label='Validation Loss', color='green')
    plt.title('Training and Validation Losses over Epochs')
    plt.xlabel('Epochs')
    plt.ylabel('Loss')
    plt.legend()
    plt.grid(True)
    plt.savefig("losses_plot.png")

    # Metrics
    plt.figure(figsize=(10, 6))
    plt.plot(epochs_list, val_dice_scores, label='Validation Dice Score', color='red')
    plt.plot(epochs_list, val_iou_scores, label='Validation IoU Score', color='purple')
    plt.plot(epochs_list, val_acc_scores, label='Validation Accuracy', color='brown')
    plt.title('Validation Metrics over Epochs')
    plt.xlabel('Epochs')
    plt.ylabel('Score')
    plt.legend()
    plt.grid(True)
    plt.savefig("metrics_plot.png")

    plt.show()


def main():
    """Train the generator/discriminator pair with early stopping and plot the results.

    Uses the module-level `train_dataset` and `val_dataset`. The best models
    (lowest validation loss) are written to "best_generator.pth" and
    "best_discriminator.pth"; the curves go to "losses_plot.png" and
    "metrics_plot.png".
    """
    batch_size = 16
    epochs = 150
    lr = 3e-4
    patience = 20
    lambda_adv = 0.1  # Adjust the adversarial loss weight
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=4)

    generator = Generator(in_channels=3, out_channels=1).to(device)
    discriminator = Discriminator(in_channels=4).to(device)
    # Initialise only freshly created layers; the pretrained Swin blocks keep their weights.
    _init_new_layers(generator)
    discriminator.apply(weights_init_normal)

    g_criterion = ComboLoss()
    d_criterion = nn.BCEWithLogitsLoss()
    # Frozen parameters never receive gradients, so leave them out of the optimizer.
    g_optimizer = optim.Adam([p for p in generator.parameters() if p.requires_grad], lr=lr)
    d_optimizer = optim.Adam(discriminator.parameters(), lr=lr)
    # `verbose` is deprecated in recent PyTorch; LR changes are reported manually below.
    scheduler = ReduceLROnPlateau(g_optimizer, mode='min', patience=5, factor=0.5)

    best_val_loss = float('inf')
    early_stopping_counter = 0
    g_losses, d_losses = [], []
    val_losses, val_dice_scores, val_iou_scores, val_acc_scores = [], [], [], []

    for epoch in range(epochs):
        update_discriminator = epoch % 2 == 0  # Update discriminator every alternate epoch
        g_loss, d_loss, train_dice, train_iou, train_acc = train(
            generator, discriminator, train_loader, g_criterion, d_criterion, g_optimizer, d_optimizer, device, update_discriminator, lambda_adv
        )
        val_loss, val_dice, val_iou, val_acc = validate(generator, val_loader, g_criterion, device)

        lr_before = g_optimizer.param_groups[0]["lr"]
        scheduler.step(val_loss)
        lr_after = g_optimizer.param_groups[0]["lr"]
        if lr_after != lr_before:
            print(f"Generator learning rate reduced to {lr_after:.2e}.")

        print(f"Epoch [{epoch+1}/{epochs}] - Train G Loss: {g_loss:.4f}, Train D Loss: {d_loss:.4f} | Train Dice: {train_dice:.4f}, Train IoU: {train_iou:.4f}, Train Acc: {train_acc:.4f} | Val Loss: {val_loss:.4f}, Val Dice: {val_dice:.4f}, Val IoU: {val_iou:.4f}, Val Acc: {val_acc:.4f}")
        g_losses.append(g_loss)
        d_losses.append(d_loss)
        val_losses.append(val_loss)
        val_dice_scores.append(val_dice)
        val_iou_scores.append(val_iou)
        val_acc_scores.append(val_acc)

        if val_loss < best_val_loss:
            best_val_loss = val_loss
            torch.save(generator.state_dict(), "best_generator.pth")
            torch.save(discriminator.state_dict(), "best_discriminator.pth")
            print("Models saved with improved val loss.")
            early_stopping_counter = 0
        else:
            early_stopping_counter += 1

        if early_stopping_counter >= patience:
            print("Early stopping triggered.")
            break

        # Show a sample segmentation result from the validation set.
        _show_sample_prediction(generator, val_loader, device, epoch)

    _plot_history(g_losses, d_losses, val_losses, val_dice_scores, val_iou_scores, val_acc_scores)


if __name__=='__main__':
  main()
