In [2]:
!pip install timm

Defaulting to user installation because normal site-packages is not writeable


In [3]:
import os
import torch
from torch.utils.data import DataLoader, Dataset, random_split
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.nn import CrossEntropyLoss
import timm
import brevitas.nn as qnn
from torchvision import transforms as T
from PIL import Image
from glob import glob
import random
import torch.nn as nn
import torch.nn.functional as F
from brevitas.core.quant import QuantType

# Set the seed for reproducibility
torch.manual_seed(2023)

# Define your paths
base_dir = os.environ['FINN_ROOT'] + "/notebooks/FINN_Brevitas/"
data_path = base_dir + "Dataset_BUSI_with_GT/"

# Custom Dataset Class
class CustomDataset(Dataset):
    def __init__(self, root, transformations=None):
        self.transformations = transformations
        self.im_paths = []
        self.cls_names = {}
        
        # Load all images and infer class labels from folder names
        class_folders = sorted(glob(f"{root}/*"))  # e.g., Benign, Malignant, Normal
        print(f"Found class folders: {class_folders}")
        
        for idx, folder in enumerate(class_folders):
            class_name = os.path.basename(folder)  # e.g., "Benign"
            self.cls_names[class_name] = idx
            folder_images = glob(f"{folder}/*.png")
            print(f"Class '{class_name}' has {len(folder_images)} images.")
            self.im_paths.extend([(im_path, idx) for im_path in folder_images])
        
        self.im_paths.sort()  # Sort for reproducibility
    
    def __len__(self):
        return len(self.im_paths)
    
    def __getitem__(self, idx):
        im_path, label = self.im_paths[idx]
        im = Image.open(im_path).convert("L")  # Convert to grayscale
        if self.transformations:
            im = self.transformations(im)
        return im, label
# Split dataset
def split_dataset(dataset, train_ratio=0.8, val_ratio=0.1):
    total_len = len(dataset)
    train_len = int(total_len * train_ratio)
    val_len = int(total_len * val_ratio)
    test_len = total_len - train_len - val_len
    return random_split(dataset, [train_len, val_len, test_len])

# Dataloader Function
def get_dls(root, transformations, bs, ns=4):
    dataset = CustomDataset(root=root, transformations=transformations)
    train_ds, val_ds, test_ds = split_dataset(dataset)
    tr_dl = DataLoader(train_ds, batch_size=bs, shuffle=True, num_workers=4)
    val_dl = DataLoader(val_ds, batch_size=bs, shuffle=False, num_workers=4)
    ts_dl = DataLoader(test_ds, batch_size=1, shuffle=False, num_workers=ns)
    return tr_dl, val_dl, ts_dl, dataset.cls_names

# Root path for dataset
root = data_path
# Define advanced transformations with grayscale and augmentations
mean, std, im_size = [0.485], [0.229], 224
tfs = T.Compose([
    T.Resize((im_size, im_size)),
    T.RandomHorizontalFlip(),
    T.RandomRotation(20),
    T.RandomResizedCrop(im_size, scale=(0.8, 1.0)),
    T.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.2),
    T.RandomAffine(degrees=20, translate=(0.1, 0.1), shear=10),
    T.RandomGrayscale(p=1.0),  # Ensure images are converted to grayscale
    T.ToTensor(),
    T.RandomErasing(p=0.2),
    T.Normalize(mean=mean, std=std)
])

In [4]:
# Create DataLoaders
tr_dl, val_dl, ts_dl, classes = get_dls(root=root, transformations=tfs, bs=16)
print("Number of batches in training loader:", len(tr_dl))
print("Number of batches in validation loader:", len(val_dl))
print("Number of batches in test loader:", len(ts_dl))
print("Classes:", classes)

Found class folders: ['/home/administrateur/finn/notebooks/FINN_Brevitas/Dataset_BUSI_with_GT/Benign', '/home/administrateur/finn/notebooks/FINN_Brevitas/Dataset_BUSI_with_GT/Malignant', '/home/administrateur/finn/notebooks/FINN_Brevitas/Dataset_BUSI_with_GT/Normal']
Class 'Benign' has 437 images.
Class 'Malignant' has 210 images.
Class 'Normal' has 133 images.
Number of batches in training loader: 39
Number of batches in validation loader: 5
Number of batches in test loader: 78
Classes: {'Benign': 0, 'Malignant': 1, 'Normal': 2}


In [5]:
# SEBlock and ResidualBlock Definitions
class SEBlock(nn.Module):
    def __init__(self, channels, reduction=16):
        super(SEBlock, self).__init__()
        self.global_pool = nn.AdaptiveAvgPool2d(1)
        self.fc1 = nn.Linear(channels, channels // reduction)
        self.act = nn.ReLU()
        self.fc2 = nn.Linear(channels // reduction, channels)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        batch, channels, _, _ = x.size()
        y = self.global_pool(x).view(batch, channels)
        y = self.fc2(self.act(self.fc1(y))).view(batch, channels, 1, 1)
        y = self.sigmoid(y)
        return x * y

class ResidualBlock(nn.Module):
    def __init__(self, in_channels, out_channels, stride=1, downsample=None, use_se=False):
        super(ResidualBlock, self).__init__()
        self.conv1 = qnn.QuantConv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, weight_bit_width=8, bias=False, quant_type=QuantType.INT)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.act1 = qnn.QuantReLU(bit_width=8, quant_type=QuantType.INT)
        self.conv2 = qnn.QuantConv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, weight_bit_width=8, bias=False, quant_type=QuantType.INT)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.downsample = downsample
        self.act2 = qnn.QuantReLU(bit_width=8, quant_type=QuantType.INT)
        self.use_se = use_se
        if use_se:
            self.se = SEBlock(out_channels)

    def forward(self, x):
        identity = x
        if self.downsample is not None:
            identity = self.downsample(x)
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.act1(out)
        out = self.conv2(out)
        out = self.bn2(out)
        if self.use_se:
            out = self.se(out)
        out += identity
        out = self.act2(out)
        return out

# Model Definition
class AdvancedUltrasoundNet(nn.Module):
    def __init__(self, num_classes):
        super(AdvancedUltrasoundNet, self).__init__()
        self.conv1 = qnn.QuantConv2d(1, 64, kernel_size=7, stride=2, padding=3, weight_bit_width=8, bias=False, quant_type=QuantType.INT)
        self.bn1 = nn.BatchNorm2d(64)
        self.act1 = qnn.QuantReLU(bit_width=8, quant_type=QuantType.INT)
        self.pool1 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        # Residual Blocks with SE
        self.res1 = ResidualBlock(64, 64, use_se=True)
        self.res2 = ResidualBlock(64, 128, stride=2, downsample=self._downsample(64, 128), use_se=True)
        self.res3 = ResidualBlock(128, 256, stride=2, downsample=self._downsample(128, 256), use_se=True)
        self.res4 = ResidualBlock(256, 512, stride=2, downsample=self._downsample(256, 512), use_se=True)

        # Fully Connected Layers
        self.global_pool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc1 = qnn.QuantLinear(512, 256, bias=True, weight_bit_width=8, quant_type=QuantType.INT)
        self.act2 = qnn.QuantReLU(bit_width=8, quant_type=QuantType.INT)
        self.dropout = nn.Dropout(p=0.5)
        self.fc2 = qnn.QuantLinear(256, 128, bias=True, weight_bit_width=8, quant_type=QuantType.INT)
        self.fc3 = qnn.QuantLinear(128, num_classes, bias=True, weight_bit_width=8, quant_type=QuantType.INT)

    def _downsample(self, in_channels, out_channels):
        return nn.Sequential(
            qnn.QuantConv2d(in_channels, out_channels, kernel_size=1, stride=2, weight_bit_width=8, bias=False, quant_type=QuantType.INT),
            nn.BatchNorm2d(out_channels)
        )

    def forward(self, x):
        x = self.pool1(self.act1(self.bn1(self.conv1(x))))
        x = self.res1(x)
        x = self.res2(x)
        x = self.res3(x)
        x = self.res4(x)
        x = self.global_pool(x)
        x = torch.flatten(x, 1)
        x = self.dropout(self.act2(self.fc1(x)))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

# Initialize model, criterion, optimizer, and scheduler
model = AdvancedUltrasoundNet(num_classes=len(classes))

In [6]:
# Loss and optimizer
criterion = CrossEntropyLoss()
optimizer = AdamW(model.parameters(), lr=1e-4, weight_decay=1e-4)  # Adjusted learning rate and weight decay
scheduler = CosineAnnealingLR(optimizer, T_max=200, eta_min=1e-6)  # Adjusted Cosine Annealing with warm-up strategy
clip_value = 1.0  # Gradient clipping to stabilize training

In [7]:
# Training loop
def train_model(model, tr_dl, val_dl, criterion, optimizer, scheduler, epochs=200):
    for epoch in range(epochs):
        model.train()
        running_loss = 0.0
        correct = 0
        total = 0

        for inputs, labels in tr_dl:
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), clip_value)
            optimizer.step()

            running_loss += loss.item()
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

        scheduler.step()

        print(f"Epoch [{epoch+1}/{epochs}], Loss: {running_loss/len(tr_dl):.4f}, Accuracy: {100 * correct/total:.2f}%")

        # Validation phase
        model.eval()
        val_loss = 0.0
        correct_val = 0
        total_val = 0
        with torch.no_grad():
            for inputs, labels in val_dl:
                outputs = model(inputs)
                loss = criterion(outputs, labels)
                val_loss += loss.item()
                _, predicted = torch.max(outputs, 1)
                total_val += labels.size(0)
                correct_val += (predicted == labels).sum().item()

        print(f"Validation Loss: {val_loss/len(val_dl):.4f}, Validation Accuracy: {100 * correct_val/total_val:.2f}%")

In [8]:
# Train the model
train_model(model, tr_dl, val_dl, criterion, optimizer, scheduler, epochs=200)

  return super(Tensor, self).rename(names)


Epoch [1/200], Loss: 1.0190, Accuracy: 56.09%
Validation Loss: 0.9579, Validation Accuracy: 53.85%
Epoch [2/200], Loss: 0.9619, Accuracy: 55.93%
Validation Loss: 0.8727, Validation Accuracy: 62.82%
Epoch [3/200], Loss: 0.9408, Accuracy: 57.05%
Validation Loss: 0.8581, Validation Accuracy: 62.82%
Epoch [4/200], Loss: 0.9146, Accuracy: 58.17%
Validation Loss: 0.7886, Validation Accuracy: 71.79%
Epoch [5/200], Loss: 0.9020, Accuracy: 58.81%
Validation Loss: 0.9192, Validation Accuracy: 55.13%
Epoch [6/200], Loss: 0.8817, Accuracy: 59.78%
Validation Loss: 0.8659, Validation Accuracy: 55.13%
Epoch [7/200], Loss: 0.8548, Accuracy: 60.58%
Validation Loss: 0.8426, Validation Accuracy: 56.41%
Epoch [8/200], Loss: 0.8339, Accuracy: 62.98%
Validation Loss: 0.6726, Validation Accuracy: 65.38%
Epoch [9/200], Loss: 0.8023, Accuracy: 64.10%
Validation Loss: 0.7437, Validation Accuracy: 65.38%
Epoch [10/200], Loss: 0.7870, Accuracy: 65.06%
Validation Loss: 0.6926, Validation Accuracy: 69.23%
Epoch [11

In [None]:
# Save the trained model
torch.save(model.state_dict(), "resquantized_cnn.pth")

In [None]:
# Evaluate on test set
model.eval()
correct_test = 0
total_test = 0
with torch.no_grad():
    for inputs, labels in ts_dl:
        outputs = model(inputs)
        _, predicted = torch.max(outputs, 1)
        total_test += labels.size(0)
        correct_test += (predicted == labels).sum().item()

print(f"Test Accuracy: {100 * correct_test/total_test:.2f}%")