In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import Sequential as Seq, Linear as Lin, Conv2d
import numpy as np

In [None]:
# Glob patterns that match every entry in the training and validation folders.
train_pure, val_pure = (
    '/kaggle/input/breast-256/gcn_256/Train_gcn/*',
    '/kaggle/input/breast-256/gcn_256/Val_gcn/*',
)

In [None]:
import glob
# glob.glob returns matches in arbitrary, filesystem-dependent order; sorting
# keeps dataset indices stable from one run (or machine) to the next.
train_files = sorted(glob.glob(train_pure))
val_files = sorted(glob.glob(val_pure))
len(train_files)  # original author's note: 6229

In [None]:
!pip install barbar
!pip install swin-transformer-pytorch

In [None]:
from torchvision import transforms
#import matplotlib.pyplot as plt
import torch
from torchvision import datasets, transforms
from torch.utils.data import Dataset, DataLoader
import torchvision.models as models
from PIL import Image
import albumentations
import albumentations.pytorch 
import cv2
import torch.nn as nn
import copy
from barbar import Bar
import io
import numpy as np

In [None]:
# Use the GPU when CUDA is usable, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(device)

In [None]:
# Show how many CUDA devices PyTorch reports for this session.
torch.cuda.device_count()

In [None]:
# Seed every RNG the pipeline can draw from, not only the CUDA generators:
# DataLoader shuffling uses torch's CPU generator, and augmentation code
# presumably draws from Python's `random` and/or NumPy -- TODO confirm for the
# installed albumentations version.
import random

SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

if torch.cuda.is_available():
    torch.cuda.manual_seed(SEED)
    torch.cuda.manual_seed_all(SEED)
    # Speed is preferred over bit-exact GPU reproducibility here: benchmark mode
    # is on and deterministic cuDNN kernels are not enforced.
    torch.backends.cudnn.benchmark = True
    torch.backends.cudnn.deterministic = False

In [None]:
class My_data(Dataset):
    """Image dataset whose multi-hot labels are parsed from each file name.

    Args:
        data: sequence of image file paths.
        transforms: optional callable invoked as ``transforms(image=array)``
            that returns a mapping with the processed image under ``'image'``.

    Each item is ``(image, labels)`` where ``labels`` is an integer tensor with
    one 0/1 entry per code in ``self.eicls``.
    """

    def __init__(self, data, transforms=None):
        self.image_list = data
        self.data_len = len(self.image_list)
        self.transforms = transforms
        # Class codes; the position in this list is the label index.
        self.eicls = ["A", "F", "TA", "PT", "DC", "LC", "MC", "PC"]

    def _encode_labels(self, image_path):
        """Return the 0/1 label list encoded in the file name of ``image_path``.

        The text after the last ``_`` of the file name is split on ``-``;
        field 0 is compared against the class codes.
        """
        import os  # local import: `os` is not among this notebook's imports

        # Parse the file name only, so underscores or dashes in directory
        # names cannot shift the fields.
        file_name = os.path.basename(image_path)
        parts = file_name.split('_')[-1].split('-')

        # Files whose third field is "13412" get both DC and LC set.
        # NOTE(review): presumably a case carrying two diagnoses -- confirm.
        if len(parts) > 2 and parts[2] == "13412":
            return [0, 0, 0, 0, 1, 1, 0, 0]
        return [int(label == parts[0]) for label in self.eicls]

    def __getitem__(self, index):
        # An out-of-range index raises IndexError here, which is what ends a
        # plain `for item in dataset` loop.
        current_image_path = self.image_list[index]

        im_as_im = cv2.imread(current_image_path)
        if im_as_im is None:
            # cv2.imread signals failure by returning None instead of raising.
            raise FileNotFoundError(
                f"cv2.imread could not read image: {current_image_path!r}"
            )
        im_as_im = cv2.cvtColor(im_as_im, cv2.COLOR_BGR2RGB)

        labels = torch.tensor(self._encode_labels(current_image_path))

        if self.transforms is not None:
            im_as_im = self.transforms(image=im_as_im)['image']

        return (im_as_im, labels)

    def __len__(self):
        return self.data_len

In [None]:
# Per-channel normalisation statistics. Both pipelines share one set so that
# validation images are scaled exactly like the images the model is trained on
# (previously the validation means differed slightly from the training means).
NORM_MEAN = (0.787, 0.625, 0.765)
NORM_STD = (0.105, 0.138, 0.089)
IMAGE_SIZE = 256

transform = {
    # Training: resize, one augmentation choice (NoOp is among the options),
    # then normalise and convert to a tensor.
    'train': albumentations.Compose([
        albumentations.Resize(IMAGE_SIZE, IMAGE_SIZE),
        albumentations.OneOf([
            albumentations.HorizontalFlip(),
            albumentations.Rotate(limit=45),
            albumentations.VerticalFlip(),
            albumentations.GaussianBlur(),
            albumentations.NoOp(),
        ], p=1),
        albumentations.Normalize(mean=NORM_MEAN, std=NORM_STD, p=1),
        albumentations.pytorch.transforms.ToTensorV2(),
    ]),

    # Validation: deterministic resize + normalise only.
    'valid': albumentations.Compose([
        albumentations.Resize(IMAGE_SIZE, IMAGE_SIZE),
        albumentations.Normalize(mean=NORM_MEAN, std=NORM_STD, p=1),
        albumentations.pytorch.transforms.ToTensorV2(),
    ]),
}

In [None]:
# Wrap each file list in a dataset with its matching transform pipeline.
train = My_data(train_files, transforms=transform['train'])
valid = My_data(val_files, transforms=transform['valid'])

# Per-class positive counts: add up the multi-hot label vectors of the whole
# training set. Every item is fetched through the dataset, so each image is
# read and transformed just to obtain its label.
a = sum((label for _, label in train), torch.zeros(8, dtype=torch.int64))
print(a)


In [None]:
import torch

# Positive-sample count per class, in label-index order.
class_samples = [367, 803, 456, 370, 2763, 492, 629, 449]
total_samples = sum(class_samples)
# Mean count per class; the reference point for the weights below.
samples = total_samples / len(class_samples)
# weight = mean count / class count, so rarer classes weigh more.
# The small epsilon guards against a zero count.
class_weights = torch.tensor(
    list(map(lambda count: samples / (count + 1e-8), class_samples))
)
print(class_weights)

In [None]:
# Number of items in the validation dataset.
len(valid)

In [None]:
# Settings common to both loaders; only shuffling differs between them.
_loader_kwargs = dict(batch_size=16, num_workers=2, pin_memory=True, prefetch_factor=2)

# Training batches are reshuffled every epoch; validation keeps a fixed order.
train_dataloader = torch.utils.data.DataLoader(dataset=train, shuffle=True, **_loader_kwargs)
valid_dataloader = torch.utils.data.DataLoader(dataset=valid, shuffle=False, **_loader_kwargs)

In [None]:
# Sanity check: pull one validation batch and print it.
first_batch = next(iter(valid_dataloader), None)
if first_batch is not None:
    print(first_batch)

In [None]:
#pip install timm

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from torch.utils.data import DataLoader
import timm
from timm.models import create_model
from timm.data import create_transform
from sklearn.metrics import accuracy_score

# Alternative backbone kept for reference (disabled):
#   model_name = 'swin_base_patch4_window7_224'
#   num_classes = 8
#   model = create_model(
#       model_name=model_name,
#       pretrained=True,
#       num_classes=num_classes,
#       drop_rate=0.5,
#       drop_path_rate=0.2,
#       checkpoint_path=None
#   )
# (Disabled code is written as real comments: a bare string literal is still
# evaluated, and one at the end of a cell is echoed as the cell's output.)

# Swin Transformer V2 classifier created through timm with pretrained weights
# and an 8-output head, one output per class code used by the dataset.
model = timm.create_model(
    'swinv2_tiny_window8_256.ms_in1k',
    pretrained=True,
    features_only=False,
    num_classes=8,
    drop_path_rate=0.2,
    drop_rate=0.5,
)

# Selective fine-tuning of the head and final norm (disabled):
#   for param in model.head.parameters():
#       param.requires_grad = True
#   for param in model.norm.parameters():
#       param.requires_grad = True
#model=model.to(device)

In [None]:
#model

import torch


# Report, for every named parameter, whether it will receive gradients.
for name, param in model.named_parameters():
    status = "requires grad" if param.requires_grad else "does not require grad"
    print(f"Parameter '{name}' {status}.")


import torch
import torch.nn as nn

class FocalLossWithClassWeights(nn.Module):
    """Multi-label focal loss with a per-class weight on every loss term.

    For logits ``x``, ``p = sigmoid(x)`` and targets ``y`` the element-wise
    loss is::

        -w_c * [ alpha * (1 - p)^gamma * y * log(p)
                 + p^gamma * (1 - y) * log(1 - p) ]

    and the result is the mean over all elements.

    Args:
        class_weights: tensor broadcastable against the logits (one weight per
            class); it scales the loss, not the logits.
        alpha: extra factor on the positive (``y == 1``) term.
        gamma: focusing exponent; 0 turns the focal modulation off.
    """

    def __init__(self, class_weights, alpha=1, gamma=2):
        super(FocalLossWithClassWeights, self).__init__()
        self.class_weights = class_weights
        self.alpha = alpha
        self.gamma = gamma
        self.sigmoid = nn.Sigmoid()

    def forward(self, input, target):
        """Return the scalar weighted focal loss for ``input`` logits."""
        class_weights = self.class_weights.to(device=input.device, dtype=input.dtype)
        target = target.to(dtype=input.dtype)

        probs = self.sigmoid(input)
        # logsigmoid is finite for large |logit|, unlike log(sigmoid(x) + eps).
        log_p = nn.functional.logsigmoid(input)
        log_not_p = nn.functional.logsigmoid(-input)

        # Down-weight easy examples on both sides: confident positives through
        # (1 - p)^gamma and confident negatives through p^gamma.
        pos_term = self.alpha * torch.pow(1 - probs, self.gamma) * target * log_p
        neg_term = torch.pow(probs, self.gamma) * (1 - target) * log_not_p

        # Weights multiply the loss; multiplying the logits would instead
        # change the predicted probabilities themselves.
        loss = -class_weights * (pos_term + neg_term)
        return loss.mean()



In [None]:
import torch
import torch.nn as nn

class FocalLoss(nn.Module):
    """Multi-label (sigmoid) focal loss with optional per-class weights.

    Element-wise loss: ``w_c * (1 - p_t)^gamma * BCE``, where ``p_t`` is the
    probability assigned to the true label; the result is the mean over all
    elements, i.e. always a scalar.

    Args:
        gamma: focusing exponent; 0 reduces the loss to (weighted) BCE.
        class_weights: optional tensor broadcastable against the logits.
    """

    def __init__(self, gamma=2.0, class_weights=None):
        super(FocalLoss, self).__init__()
        # Compatibility with ``FocalLoss(weight_tensor)``: a non-scalar tensor
        # arriving as ``gamma`` with no ``class_weights`` is taken to be the
        # per-class weights. (Before, ``gamma`` had no effect at all whenever
        # ``class_weights`` was None, so no working usage changes.)
        if class_weights is None and torch.is_tensor(gamma) and gamma.dim() > 0:
            gamma, class_weights = 2.0, gamma
        self.gamma = gamma
        self.class_weights = class_weights

    def forward(self, logits, labels):
        """Return the scalar focal loss of ``logits`` against 0/1 ``labels``."""
        labels = labels.to(dtype=logits.dtype)
        # Keep the per-element loss (reduction='none') so the focal factor and
        # class weights can be applied before averaging.
        bce = nn.functional.binary_cross_entropy_with_logits(
            logits, labels, reduction='none'
        )
        # For 0/1 labels BCE == -log(p_t), hence p_t == exp(-BCE).
        p_t = torch.exp(-bce)
        loss = (1 - p_t).pow(self.gamma) * bce
        if self.class_weights is not None:
            loss = loss * self.class_weights.to(device=logits.device, dtype=logits.dtype)
        return loss.mean()


# Make every model parameter trainable.
for weight in model.parameters():
    weight.requires_grad_(True)


In [None]:
from torch.optim.lr_scheduler import ReduceLROnPlateau
#model = nn.DataParallel(model, device_ids = [0, 1])
model = model.to(device)
class_weights = class_weights.to(device)
#criterion = torch.nn.BCEWithLogitsLoss(weight=class_weights_normalized)
#criterion = FocalLossWithClassWeights(class_weights)
# NOTE(review): class_weights is passed positionally, so it binds to the first
# parameter of FocalLoss.__init__, which is named `gamma` -- confirm FocalLoss
# treats it as per-class weights as intended.
criterion = FocalLoss(class_weights)
# Only parameters with requires_grad=True at this point are handed to the
# optimizer; parameters enabled later are not added automatically.
optimizer = optim.AdamW(
    filter(lambda p: p.requires_grad, model.parameters()),
    lr=1e-4)
# Halve the learning rate after every 15 calls to scheduler.step().
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=15, gamma=0.5)
# state_dict() returns references to the live tensors; deep-copy them so these
# snapshots do not silently change while training continues.
best_model_wts = copy.deepcopy(model.state_dict())
best_optimizer_state = copy.deepcopy(optimizer.state_dict())
best_acc = 0.0

In [None]:
def fit(model, dataloader, optimizer, scheduler, criterion, accum_iter=4, threshold=0.5):
    """Train ``model`` for one epoch with gradient accumulation.

    Reads the notebook globals ``device`` (target device) and ``Bar``
    (progress wrapper around the loader).

    Args:
        model: network mapping a batch of inputs to per-class logits.
        dataloader: yields ``(inputs, labels)`` with multi-hot ``labels``.
        optimizer: stepped once every ``accum_iter`` batches (and once more
            for a trailing partial group).
        scheduler: stepped once, at the end of the epoch.
        criterion: callable ``criterion(logits, float_labels)`` returning a
            scalar loss.
        accum_iter: number of batches whose gradients are accumulated before
            each optimizer step.
        threshold: probability cut-off for a positive prediction; a scalar or
            one value per class.

    Returns:
        ``(mean_loss, exact_match_accuracy_percent)`` as Python floats. A
        sample counts as correct only if all of its labels are predicted
        correctly.
    """
    model.train()
    running_loss = 0.0
    running_correct = 0
    num_batches = len(dataloader)

    # Gradients are cleared only here and right after each optimizer step, so
    # they really accumulate over `accum_iter` batches.
    optimizer.zero_grad()

    for step, (inputs, labels) in enumerate(Bar(dataloader), start=1):
        inputs = inputs.to(device)
        labels = labels.float().to(device)

        outputs = model(inputs)
        loss = criterion(outputs, labels)
        running_loss += loss.item() * inputs.size(0)

        # Metrics only: no gradient tracking needed.
        with torch.no_grad():
            probs = torch.sigmoid(outputs)
            cutoff = torch.as_tensor(threshold, dtype=probs.dtype, device=probs.device)
            preds = (probs >= cutoff).float()
            running_correct += (preds == labels).all(dim=1).sum().item()

        # Scale so the accumulated gradient is the average over the group.
        (loss / accum_iter).backward()

        # `step` is the batch counter and is not reused by any inner loop, so
        # this fires every `accum_iter` batches; the second condition flushes
        # a final incomplete group.
        if step % accum_iter == 0 or step == num_batches:
            optimizer.step()
            optimizer.zero_grad()

    scheduler.step()

    num_samples = len(dataloader.dataset)
    train_loss = running_loss / num_samples
    train_accuracy = 100.0 * running_correct / num_samples
    return train_loss, train_accuracy

In [None]:
def validate(model, dataloader, optimizer, criterion, threshold=0.5):
    """Evaluate ``model`` on ``dataloader`` without updating any weights.

    Reads the notebook global ``device``.

    Args:
        model: network mapping a batch of inputs to per-class logits.
        dataloader: yields ``(inputs, labels)`` with multi-hot ``labels``.
        optimizer: unused; kept so existing calls keep working.
        criterion: callable ``criterion(logits, float_labels)`` returning a
            scalar loss.
        threshold: probability cut-off for a positive prediction; a scalar or
            one value per class.

    Returns:
        ``(mean_loss, exact_match_accuracy_percent)`` as Python floats, so the
        values can be logged and saved without carrying device tensors along.
    """
    model.eval()
    running_loss = 0.0
    running_correct = 0

    with torch.no_grad():
        for inputs, labels in dataloader:
            inputs = inputs.to(device)
            labels = labels.float().to(device)

            outputs = model(inputs)
            running_loss += criterion(outputs, labels).item() * inputs.size(0)

            probs = torch.sigmoid(outputs)
            cutoff = torch.as_tensor(threshold, dtype=probs.dtype, device=probs.device)
            preds = (probs >= cutoff).float()
            # A sample is correct only when every label matches.
            running_correct += (preds == labels).all(dim=1).sum().item()

    num_samples = len(dataloader.dataset)
    val_loss = running_loss / num_samples
    val_accuracy = 100.0 * running_correct / num_samples
    return val_loss, val_accuracy

import time as time
# One row per epoch: [epoch, train loss, train acc, val loss, val acc, seconds].
history = []
best_model_wts = copy.deepcopy(model.state_dict())
#best_optimizer_state = copy.deepcopy(optimizer.state_dict())
best_acc = 0.0
epochs = 50

for epoch in range(epochs):
    epoch_start = time.time()
    print('Epoch-{0}/{1} lr: {2}'.format(epoch+1,epochs ,optimizer.param_groups[0]['lr']))
    if epoch > 14:
        # NOTE(review): the optimizer was built from the parameters that had
        # requires_grad=True at that time, so enabling gradients here does not
        # add any parameter to the optimizer -- confirm this is still wanted.
        for param in model.parameters():
            param.requires_grad = True

    train_epoch_loss, train_epoch_accuracy = fit(model,train_dataloader,optimizer,scheduler,criterion)
    val_epoch_loss, val_epoch_accuracy = validate(model,valid_dataloader,optimizer,criterion)

    # Store plain Python floats: float() accepts both numbers and 0-dim
    # tensors, and keeps device tensors out of `history` and the saved files.
    train_epoch_loss, train_epoch_accuracy = float(train_epoch_loss), float(train_epoch_accuracy)
    val_epoch_loss, val_epoch_accuracy = float(val_epoch_loss), float(val_epoch_accuracy)

    epoch_end = time.time()
    history.append([epoch+1,train_epoch_loss, train_epoch_accuracy, val_epoch_loss, val_epoch_accuracy,(epoch_end-epoch_start)])
    print(f"Train Loss: {train_epoch_loss:.4f}, Train Acc: {train_epoch_accuracy:.2f},Val Loss: {val_epoch_loss:.4f}, Val Acc: {val_epoch_accuracy:.2f},time : {epoch_end-epoch_start:.2f}")
    # Rewritten every epoch so the metrics survive an interrupted run.
    torch.save({'history':history},'Master_his.pth')

    # Checkpoint whenever validation accuracy improves.
    if val_epoch_accuracy > best_acc:
        best_acc = val_epoch_accuracy
        best_model_wts = copy.deepcopy(model.state_dict())
        best_epoch = epoch
        torch.save({
            'epoch': epoch+1,
            'model_state_dict': best_model_wts,
            # The validation loss of this epoch. Storing the criterion module
            # here instead would pickle a notebook-defined class into the file.
            'loss': val_epoch_loss,
            'history':history,
            'best_epoch': best_epoch+1,
            }, 'Master.pth')
  

In [None]:
import time as time
# NOTE(review): model, optimizer and scheduler are not re-created in this cell,
# so it continues from whatever state they are in when it runs.
# One row per epoch: [epoch, train loss, train acc, val loss, val acc, seconds].
history = []
best_model_wts = copy.deepcopy(model.state_dict())
#best_optimizer_state = copy.deepcopy(optimizer.state_dict())
best_acc = 0.0
epochs = 50

for epoch in range(epochs):
    epoch_start = time.time()
    print('Epoch-{0}/{1} lr: {2}'.format(epoch+1,epochs ,optimizer.param_groups[0]['lr']))

    train_epoch_loss, train_epoch_accuracy = fit(model,train_dataloader,optimizer,scheduler,criterion)
    val_epoch_loss, val_epoch_accuracy = validate(model,valid_dataloader,optimizer,criterion)

    # Store plain Python floats: float() accepts both numbers and 0-dim
    # tensors, and keeps device tensors out of `history` and the saved files.
    train_epoch_loss, train_epoch_accuracy = float(train_epoch_loss), float(train_epoch_accuracy)
    val_epoch_loss, val_epoch_accuracy = float(val_epoch_loss), float(val_epoch_accuracy)

    epoch_end = time.time()
    history.append([epoch+1,train_epoch_loss, train_epoch_accuracy, val_epoch_loss, val_epoch_accuracy,(epoch_end-epoch_start)])
    print(f"Train Loss: {train_epoch_loss:.4f}, Train Acc: {train_epoch_accuracy:.2f},Val Loss: {val_epoch_loss:.4f}, Val Acc: {val_epoch_accuracy:.2f},time : {epoch_end-epoch_start:.2f}")
    # Rewritten every epoch so the metrics survive an interrupted run.
    torch.save({'history':history},'eft_his_tiny.pth')

    # Checkpoint whenever validation accuracy improves.
    if val_epoch_accuracy > best_acc:
        best_acc = val_epoch_accuracy
        best_model_wts = copy.deepcopy(model.state_dict())
        best_epoch = epoch
        torch.save({
            'epoch': epoch+1,
            'model_state_dict': best_model_wts,
            # The validation loss of this epoch. Storing the criterion module
            # here instead would pickle a notebook-defined class into the file.
            'loss': val_epoch_loss,
            'history':history,
            'best_epoch': best_epoch+1,
            }, 'eft_tiny.pth')
  