<a href="https://colab.research.google.com/github/carlosspino/SWSSL/blob/main/SWSSL_notebook.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# **NEURAL NETWORK PROJECT: SWSSL, Sliding window-based self-supervised learning for anomaly detection in high-resolution images**



We extend anomaly detection to high-resolution images by proposing to train the network and perform inference at the patch level, through the sliding window algorithm. The model is trained on chest or DBT (Digital Breast Tomosynthesis) images, and the training process involves learning features that can discriminate between positive (normal) and negative (anomalous) samples.
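
To make the patch-level idea concrete, the sketch below enumerates the top-left corners of the windows that a sliding window visits. It assumes a 1024×1024 image (the size produced by the `Resize((256*4, 256*4))` transform in `train`) together with the script defaults `patch_size=128` and `step_size=32`. `window_origins` is a hypothetical helper, not part of the project code, and the real dataset classes may handle image borders differently.

```python
def window_origins(height, width, patch_size=128, step_size=32):
    """Return (top, left) corners of every full patch inside the image."""
    rows = range(0, height - patch_size + 1, step_size)
    cols = range(0, width - patch_size + 1, step_size)
    return [(top, left) for top in rows for left in cols]


origins = window_origins(1024, 1024)
print(len(origins))   # number of patches extracted from one image
print(origins[:3])    # first few window corners
```

With these settings each axis has 29 window positions, so a single image yields 841 overlapping patches.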


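The `train_network_dbt.py` script defines the following functions (`create_dataset`, `create_dataloader` and `evaluate_and_save_model` are helpers nested inside `train`):

- `twin_loss(f_patch1, f_patch2, f_neg=None, p=False, target=None)`
- `train(model, device, args)`
- `create_dataset(category, phase, patch=True)`
- `create_dataloader(dataset, batch_size, shuffle=True, drop_last=False)`
- `evaluate_and_save_model(epoch, model, args, best_score)`
- `get_args()`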

***twin_loss Method***

**1-Unpack data:**

We first read the batch size and the feature dimension from the shape of the patch features. The batch size is needed to scale the similarity matrix and to build the identity matrix used in the loss.

**2-Feature normalizations:**

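Each feature vector is L2-normalized with `F.normalize(..., dim=-1)`: it is divided by its Euclidean norm so that it has unit length (the mean is not subtracted). With unit-length features, the dot product between two vectors equals their cosine similarity, which keeps the similarity scores in a bounded range.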
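
As a rough illustration of what `F.normalize(x, dim=-1)` does to a single feature vector, the pure-Python sketch below divides a vector by its Euclidean norm. `l2_normalize` is a hypothetical helper written for this example; the `eps` guard mirrors the small constant PyTorch uses to avoid division by zero.

```python
import math


def l2_normalize(vector, eps=1e-12):
    """Scale a vector to unit Euclidean length (one row of F.normalize)."""
    norm = math.sqrt(sum(x * x for x in vector))
    return [x / max(norm, eps) for x in vector]


unit = l2_normalize([3.0, 4.0])
print(unit)  # [0.6, 0.8]
```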

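**3-The positive score:**

The positive score is the matrix of dot products between the normalized features of the two views, divided by the batch size. It quantifies how similar the patches are to each other.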

**4-Calculation of Difference and Loss:**

We compute the squared difference between the positive score and an identity matrix.
This penalizes any deviation of the positive score from the identity. Summing the diagonal of this squared difference gives the initial loss, which represents the discrepancy between the features of matching patches that we aim to minimize during training.

**5-Weight for Non-diagonal elements:**

We build a weight matrix that is zero on the diagonal and 1e-6 everywhere else, so that only the non-diagonal elements of the difference are selected and their penalty is scaled down.
We then multiply the difference by this weight and add its sum to the initial loss to obtain the total loss.
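
Written out, steps 3 to 5 correspond to the expression below, read directly from the `twin_loss` code. Here $B$ is the batch size, $\hat{f}^{(1)}_i$ and $\hat{f}^{(2)}_j$ are the normalized features of the two views, and $\lambda = 10^{-6}$ is the off-diagonal weight. The symbols are notation introduced for this note, not names used in the script.

$$
S_{ij} = \frac{1}{B}\,\hat{f}^{(1)}_i \cdot \hat{f}^{(2)}_j,
\qquad
\mathcal{L}_{\text{twin}} = \sum_{i=1}^{B} \left(S_{ii} - 1\right)^2 + \lambda \sum_{i \neq j} S_{ij}^2
$$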

**6-Additional Loss Calculation:**

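When negative features `f_neg` are provided, two further terms are added. The pair loss pushes the sigmoid of the similarity between a patch and its augmented version towards 1, and the negative loss pushes the sigmoid of the similarity between the patch and `f_neg` towards the given `target`. Together they distinguish between pairs of features that should be similar and those that should be different.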
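
The two extra terms can be traced by hand for a single sample. The sketch below assumes the cosine similarity between the normalized features is already known and applies the same sigmoid and absolute difference as `twin_loss`. `pair_term` and `neg_term` are hypothetical names for this example, and `target` stands for the similarity label `sim` supplied by the dataset.

```python
import math


def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))


def pair_term(cos_sim):
    """|sigmoid(cos) - 1|: shrinks as a patch and its augmentation align."""
    return abs(sigmoid(cos_sim) - 1.0)


def neg_term(cos_sim, target):
    """|sigmoid(cos) - target|: matches the similarity to the supplied label."""
    return abs(sigmoid(cos_sim) - target)


# Identical features (cosine similarity 1) still leave a nonzero pair term,
# because sigmoid(1) is about 0.731 rather than 1.
print(round(pair_term(1.0), 4))
print(round(neg_term(0.0, 0.5), 4))
```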

In [None]:
import torch
import torch.nn.functional as F


def twin_loss(f_patch1, f_patch2, f_neg=None, p=False, target=None):
    batch_size, dimension = f_patch1.shape

    # Features normalizations
    f_patch1_norm = F.normalize(f_patch1, dim=-1)
    f_patch2_norm = F.normalize(f_patch2, dim=-1)

    # Calculation of positive loss (tensors are created on the same device as the features)
    device = f_patch1.device
    pos_score = torch.mm(f_patch1_norm, f_patch2_norm.t()) / batch_size
    diff = (pos_score - torch.eye(batch_size, device=device)).pow(2)
    loss = diff.diag().sum()

    # Non-diagonal loss weighting: zero on the diagonal, 1e-6 elsewhere
    non_diag_weight = (torch.ones(batch_size, batch_size, device=device)
                       - torch.eye(batch_size, device=device)) * 1e-6
    diff = diff * non_diag_weight
    loss += diff.sum()

    if f_neg is not None:
        # Negative features normalization
        f_neg_norm = F.normalize(f_neg, dim=-1)

        # Loss calculation for positive and negative pairs
        pair_score = torch.mm(f_patch1_norm, f_patch2_norm.t())
        pair_sim = torch.sigmoid(pair_score.diag())
        pair_loss = torch.abs(pair_sim - torch.ones(batch_size, device=device)).sum()

        neg_score = torch.mm(f_patch1_norm, f_neg_norm.t())
        neg_sim = torch.sigmoid(neg_score.diag())
        neg_loss = torch.abs(neg_sim - target).sum()

        # Loss sum
        loss += neg_loss + pair_loss

    # Optional debug output: loss terms and a few normalized feature values
    if p:
        if f_neg is not None:
            print('pair loss ', pair_loss.item())
            print('neighbor loss ', neg_loss.item())
        print('total loss ', loss.item())
        print('feature sample:')
        print(f_patch1_norm[0][:10])
        print(f_patch2_norm[0][:10])
        print(f_patch1_norm[1][:10])

    return loss

***train Method*** *(including its three auxiliary functions)*

**1-create_dataset:**

Creates the dataset for a given category (chest or dbt) and phase (train or val), using the `ChestDataset` and `DBTDataset` classes. Any other category raises a `ValueError`.

**2-create_dataloader:**

Creates a dataloader for a dataset returned by `create_dataset`, with the specified batch size and options such as shuffling and dropping the last incomplete batch.
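
The effect of `drop_last` on the number of batches can be checked with simple arithmetic. The helper below is a hypothetical illustration, not a PyTorch function, and the figure of 841 samples is only an example size.

```python
import math


def num_batches(num_samples, batch_size, drop_last=False):
    """Number of batches a DataLoader yields for a map-style dataset."""
    if drop_last:
        return num_samples // batch_size
    return math.ceil(num_samples / batch_size)


print(num_batches(841, 300))                  # keeps the final batch of 241 samples
print(num_batches(841, 300, drop_last=True))  # discards the incomplete batch
```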

**3-evaluate_and_save_model:**

Evaluates the model's performance and saves a model checkpoint if the current score is better than the best score seen so far. It calls `evaluate_image` to evaluate the model using the train and test dataloaders.

**4-Dataloader setup:**

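Three dataloaders are built: a shuffled patch-level training loader, which the training loop iterates over in batches, plus full-image train and validation loaders, which are passed to `evaluate_image`. The validation loader uses a batch size of 1.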

**5-Optimizer Setup:**

We use stochastic gradient descent (SGD) with momentum 0.9 and weight decay 1e-5. The optimizer updates the model's parameters from the gradients computed during training.
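
For intuition, one SGD step with momentum and weight decay on a single scalar parameter can be written out by hand. This is a simplified sketch of the update rule documented for `torch.optim.SGD` (no dampening, no Nesterov momentum), using the values from the script (`lr=1e-4`, `momentum=0.9`, `weight_decay=1e-5`). `sgd_step` is a hypothetical helper, and the constant gradient of 0.5 is an arbitrary example value.

```python
def sgd_step(param, grad, velocity, lr=1e-4, momentum=0.9, weight_decay=1e-5):
    """One SGD update for a scalar parameter; returns (param, velocity)."""
    grad = grad + weight_decay * param       # L2 penalty folded into the gradient
    velocity = momentum * velocity + grad    # running momentum buffer
    param = param - lr * velocity
    return param, velocity


param, velocity = 1.0, 0.0
for _ in range(2):
    param, velocity = sgd_step(param, grad=0.5, velocity=velocity)
print(param, velocity)
```

After the second step the momentum buffer is close to 0.95, almost twice the raw gradient, which shows how momentum accumulates consistent gradient directions.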

**6-Training Loop:**

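For each batch, the model extracts features from the patch, its augmented version and a second patch (`img_2`), and `twin_loss` is computed from them. Backpropagation then computes the gradients, which are clipped element-wise to [-0.1, 0.1] before the optimizer updates the model's parameters.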
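
The `train` loop also calls `nn.utils.clip_grad_value_(model.parameters(), 0.1)` before each optimizer step, which clamps every gradient element to the range [-0.1, 0.1]. The plain-Python sketch below mimics that clamp on a list of numbers; `clip_values` is a hypothetical helper for illustration only.

```python
def clip_values(grads, clip_value=0.1):
    """Clamp each gradient element to [-clip_value, clip_value]."""
    return [max(-clip_value, min(clip_value, g)) for g in grads]


print(clip_values([0.5, -0.03, -2.0, 0.1]))  # [0.1, -0.03, -0.1, 0.1]
```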

**7-Model evaluation and saving:**

The model is evaluated every 10 epochs, skipping epoch 0, using `evaluate_and_save_model`. If the model's performance improves, a checkpoint is saved.
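
The evaluation schedule follows from the condition `epoch > 0 and epoch % 10 == 0` in `train`. Because `epoch` is zero-based, the sketch below lists which loop indices trigger an evaluation; `evaluation_epochs` is a hypothetical helper.

```python
def evaluation_epochs(total_epochs, every=10):
    """Zero-based epoch indices at which the training loop evaluates."""
    return [e for e in range(total_epochs) if e > 0 and e % every == 0]


print(evaluation_epochs(35))  # [10, 20, 30]
```

Index 10 is the eleventh pass over the training patches, so the first evaluation inside the loop happens after eleven epochs of training.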


In [None]:
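import torch
import torch.nn as nn
from PIL import Image
from torch.utils.data import DataLoader
from torchvision import transforms
from tqdm import tqdm

# ChestDataset, DBTDataset, evaluate_image and twin_loss come from the project code


def train(model, device, args):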

    def create_dataset(category, phase, patch=True):
        transforms_list = [
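            # Image.ANTIALIAS was removed in Pillow 10; LANCZOS is the same filter
            transforms.Resize((256*4, 256*4), transforms.InterpolationMode.LANCZOS),
            transforms.ToTensor()
        ] if patch else [
            transforms.Resize((256*4, 256*3), transforms.InterpolationMode.LANCZOS),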
        ]

        transforms_list = transforms.Compose(transforms_list)

        if category == 'chest':
            return ChestDataset(
                root=args.dataset_path,
                pre_transform=transforms_list,
                phase=phase,
                patch=patch,
                patch_size=args.patch_size,
                step_size=args.step_size
            )
        elif category == 'dbt':
            return DBTDataset(
                root=args.dataset_path,
                pre_transform=transforms_list,
                phase=phase,
                patch=patch,
                patch_size=args.patch_size,
                step_size=args.step_size
            )
        else:
            raise ValueError(f"Invalid category: {category}")

    def create_dataloader(dataset, batch_size, shuffle=True, drop_last=False):
        return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last)

    def evaluate_and_save_model(epoch, model, args, best_score):
        # Print the loss terms for the most recent training batch
        twin_loss(f_patch, f_aug, f_neg=f_patch2, target=sim, p=True)
        score = evaluate_image(args, model, train_loader, test_loader, device, category=args.category)
        if score > best_score:
            torch.save(model.state_dict(), f'checkpoints/{args.category}_{epoch}_{score}.pth')
            best_score = score
        print(f'img lv curr acc {score}, best acc {best_score}')
        # Return the updated best score so the caller can keep track of it
        return best_score


    # Dataloader
    train_patch_d = create_dataset(args.category, 'train', patch=True)
    train_full_d = create_dataset(args.category, 'train', patch=False)
    test_full_d = create_dataset(args.category, 'val', patch=False)

    train_patch_loader = create_dataloader(train_patch_d, args.batch_size, shuffle=True)
    train_loader = create_dataloader(train_full_d, args.batch_size, shuffle=False, drop_last=False)
    test_loader = create_dataloader(test_full_d, 1, shuffle=False)

    # Optimizer
    optimizer = torch.optim.SGD(model.parameters(), lr=args.lr, momentum=0.9, weight_decay=1e-5)

    best_score = -1
    # Baseline evaluation before training starts
    score = evaluate_image(args, model, train_loader, test_loader, device, category=args.category)

    for epoch in range(args.epochs):
        with tqdm(total=len(train_patch_d), desc=f'Epoch {epoch + 1} / {args.epochs}', unit='img') as pbar:
            for idx, data in enumerate(train_patch_loader):
                img, img_aug, img_2, sim = data

                img = img.to(device)
                img_2 = img_2.to(device)
                img_aug = img_aug.to(device)
                sim = sim.to(device)

                # Features of the patch, the second patch and the augmented patch
                f_patch, _ = model(img)
                f_patch2, _ = model(img_2)
                f_aug, _ = model(img_aug)

                loss = twin_loss(f_patch, f_aug, f_neg=f_patch2, target=sim)

                optimizer.zero_grad()
                loss.backward()
                # Clamp every gradient element to [-0.1, 0.1]
                nn.utils.clip_grad_value_(model.parameters(), 0.1)
                optimizer.step()

                # tqdm Update
                pbar.set_postfix(**{'twin loss': loss.item()})
                pbar.update(img.shape[0])

        # Evaluate every 10 epochs and keep the best score across evaluations
        if epoch > 0 and epoch % 10 == 0:
            best_score = evaluate_and_save_model(epoch, model, args, best_score)

***get_args() Method:***

**1-** This function uses the `argparse` module to parse the command-line arguments. It defines a parser and adds the arguments that can be passed to the script, such as the training phase, the dataset path, the batch size and the hyperparameters.

**2-** In the main block, we select the execution device (GPU if available, otherwise CPU) and obtain the command-line arguments by calling `get_args()`.

The model is then created, and training is started by calling the `train` function.
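
As a small illustration of how these defaults can be overridden, the sketch below builds a reduced parser with a few of the same flags and parses an explicit argument list instead of `sys.argv`. `build_parser` is a hypothetical helper; the full set of flags is defined in `get_args()`.

```python
import argparse


def build_parser():
    """Reduced version of the script's parser, for illustration only."""
    parser = argparse.ArgumentParser(description='ANOMALYDETECTION')
    parser.add_argument('--category', default='dbt')
    parser.add_argument('--batch_size', type=int, default=300)
    parser.add_argument('--lr', type=float, default=1e-4)
    parser.add_argument('--patch_size', type=int, default=128)
    return parser


# Equivalent to: python train_network_dbt.py --category chest --batch_size 64
args = build_parser().parse_args(['--category', 'chest', '--batch_size', '64'])
print(args.category, args.batch_size, args.lr, args.patch_size)
```

Flags that are not passed keep their defaults, so here `lr` stays at 1e-4 and `patch_size` at 128.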


In [None]:
import argparse

import torch

# Patch_Model and train come from the project code


def get_args():
    parser = argparse.ArgumentParser(description='ANOMALYDETECTION')

    # General settings
    parser.add_argument('--phase', choices=['train', 'test'], default='train')
    parser.add_argument('--dataset_path', default='../dbt_dataset')
    parser.add_argument('--category', default='dbt')
    parser.add_argument('--batch_size', type=int, default=300)
    parser.add_argument('--load_size', type=int, default=256)
    parser.add_argument('--input_size', type=int, default=256)
    parser.add_argument('--coreset_sampling_ratio', type=float, default=0.01)
    parser.add_argument('--project_root_path', default='results')
    parser.add_argument('--save_src_code', default=True)
    parser.add_argument('--save_anomaly_map', default=True)

    # Model hyperparameters
    parser.add_argument('--n_neighbors', type=int, default=9)
    parser.add_argument('--lr', type=float, default=1e-4)
    parser.add_argument('--k', type=int, default=9)
    parser.add_argument('--learning-rate-weights', default=0.01, type=float, metavar='LR',
                        help='base learning rate for weights')
    parser.add_argument('--learning-rate-biases', default=0.0048, type=float, metavar='LR',
                        help='base learning rate for biases and batch norm parameters')
    parser.add_argument('--weight_decay', type=float, default=1e-6)

    # Training settings
    parser.add_argument('--epochs', type=int, default=10000)
    parser.add_argument('--patch_size', type=int, default=128)
    parser.add_argument('--step_size', type=int, default=32)
    parser.add_argument('--use_tumor', type=int, default=0)

    args = parser.parse_args()
    return args

if __name__ == '__main__':
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    args = get_args()

    model = Patch_Model(input_channel=3)
    model.to(device)
    train(model, device, args)

To sum up, this script provides a framework for training anomaly detection models on medical image datasets using PyTorch. Dataset creation, data loading, model evaluation and the training loop are separated into their own functions, which keeps the implementation modular.
We think this is a comprehensive pipeline for training an anomaly detection model.


Carlos Pino Padilla and Carlos Ramírez Rodríguez de Sepúlveda.
Neural Networks, Università la Sapienza.