In [None]:
import argparse
import os
import pickle
import time

import numpy as np
import pandas as pd
import torch
from torch.utils.data import DataLoader, WeightedRandomSampler
from tqdm import tqdm

# NOTE(review): MammoDataset_Mapper, get_transforms and get_Paths are used below but are
# not imported anywhere in this notebook; import them from the Mammo-CLIP codebase.


In [None]:
import os
from pathlib import Path

import torch

from breastclip.data import DataModule
from breastclip.model import build_model

# Attribute names; their order fixes the order of the generated attribute embeddings.
ATTRIBUTES = ["mass", "suspicious_calcification"]


def generate_attribute_embs(out_dir, breast_clip_path, model):
    """
    Generate one text embedding per attribute in ``ATTRIBUTES`` and save them.

    For every attribute, its prompts are tokenized and encoded with the Mammo-CLIP
    text encoder (followed by ``clip.text_projection`` when ``clip.projection`` is
    set). The prompt embeddings are then averaged and L2-normalised. The result is
    saved with ``torch.save`` as a dict ``{attribute: numpy array}`` to
    ``<out_dir>/attr_embs_<model>.pth``.

    Parameters:
        out_dir: Directory for storing attribute embeddings (created if missing)
        breast_clip_path: Path to a Mammo-CLIP checkpoint containing "config" and "model"
        model: Tag used in the output file name (a string, not a model object)

    Raises:
        ValueError: If an attribute has no prompt list, or its prompt list is empty.
    """

    prompts_by_attr = {
        "mass": [
            # Replace with either the corresponding prompts or the sentences from the report.
        ],
        "suspicious_calcification": [
            # Replace with either the corresponding prompts or the sentences from the report.
        ],
    }

    def get_prompts(attr):
        # Fail fast, before the checkpoint is loaded. Previously an unknown attribute
        # raised UnboundLocalError and an empty list failed later in tokenization.
        if attr not in prompts_by_attr:
            raise ValueError(f"No prompts defined for attribute '{attr}'")
        prompts = prompts_by_attr[attr]
        if not prompts:
            raise ValueError(f"Prompt list for attribute '{attr}' is empty; fill it in first")
        return prompts

    # Obtain prompts for each attribute
    prompt_list = []
    for attr in ATTRIBUTES:
        print(attr)
        prompt_list.append(get_prompts(attr))
    print(len(prompt_list))

    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    breast_clip_path = Path(breast_clip_path)
    print(breast_clip_path)
    configs = torch.load(breast_clip_path, map_location=device)
    cfg = configs["config"]

    # The DataModule is only built to obtain the tokenizer matching the checkpoint.
    datamodule = DataModule(
        data_config=cfg["data_train"],
        dataloader_config=cfg["dataloader"],
        tokenizer_config=cfg["tokenizer"] if "tokenizer" in cfg else None,
        loss_config=cfg["loss"],
        transform_config=cfg["transform"],
        mean=cfg["base"]["mean"],
        std=cfg["base"]["std"],
        image_encoder_type=cfg["model"]["image_encoder"]["model_type"],
        cur_fold=cfg["base"]["fold"]
    )

    clip = build_model(cfg["model"], cfg["loss"], datamodule.tokenizer)
    clip = clip.to(device)
    clip.load_state_dict(configs["model"], strict=True)
    clip.eval()

    # Compute each attribute embedding as the average of its associated prompt embeddings
    attr_embs = []
    with torch.no_grad():
        for prompts in prompt_list:
            text_token = datamodule.tokenizer(
                prompts, padding="longest", truncation=True, return_tensors="pt", max_length=256
            )
            text_emb = clip.encode_text(text_token.to(device))
            if clip.projection:
                text_emb = clip.text_projection(text_emb)
            # Average over prompts (assumes encode_text returns one row per prompt — TODO confirm)
            text_emb = text_emb.mean(dim=0, keepdim=True)
            text_emb = text_emb / text_emb.norm(dim=-1, keepdim=True)
            attr_embs.append(text_emb)

        # cat over dim 0 keeps one row per attribute even when there is a single
        # attribute; the previous stack(...).squeeze() collapsed that axis and broke
        # the zip with ATTRIBUTES below.
        attr_embs = torch.cat(attr_embs, dim=0).cpu().numpy()

    attr_to_emb = dict(zip(ATTRIBUTES, attr_embs))
    print(attr_embs.shape)
    out_dir = Path(out_dir)
    os.makedirs(out_dir, exist_ok=True)
    out_path = out_dir / f"attr_embs_{model}.pth"
    torch.save(attr_to_emb, out_path)
    print(f"Saved {len(attr_to_emb)} attribute embeddings to {out_path}")


In [None]:
import torch

from breastclip.model.modules import load_image_encoder


class Mapper_model(torch.nn.Module):
    """
    Frozen Mammo-CLIP image encoder followed by trainable region projection heads.

    The image encoder is rebuilt from ``ckpt["config"]["model"]["image_encoder"]``,
    loaded with the ``image_encoder.*`` entries of ``ckpt["model"]``, and frozen.
    Each projection head is Linear(emb_dim, emb_dim) -> ReLU -> Linear(emb_dim, lang_emb).

    Parameters:
        ckpt: Mammo-CLIP checkpoint dict with "config" and "model" entries
        lang_emb: Output dimension of each projection head (text-embedding size)
        emb_dim: Input dimension of each projection head (region feature size)
        one_proj: If True, a single head is shared by all attributes; otherwise
            one head per attribute
        adapter: If True, blend head output with the raw region features
            (0.2 * head + 0.8 * raw), which requires lang_emb == emb_dim
        attr_embs: Attribute embeddings; only ``len(attr_embs)`` is used here
    """

    def __init__(self, ckpt, lang_emb: int, emb_dim: int, one_proj: bool, adapter: bool, attr_embs):
        super(Mapper_model, self).__init__()
        encoder_cfg = ckpt["config"]["model"]["image_encoder"]
        self.image_encoder = load_image_encoder(encoder_cfg)
        # Keep only the "image_encoder.*" weights, with that prefix stripped.
        image_encoder_weights = {
            k.split(".", 1)[1]: v for k, v in ckpt["model"].items() if k.startswith("image_encoder.")
        }
        self.image_encoder.load_state_dict(image_encoder_weights, strict=True)
        self.image_encoder_type = encoder_cfg["model_type"]
        # The image encoder is frozen; only the projection heads are trained.
        for param in self.image_encoder.parameters():
            param.requires_grad = False

        self.emb_dim = emb_dim
        self.lang_emb = lang_emb
        self.one_proj = one_proj
        self.adapter = adapter

        # Initialize projection heads
        self.num_proj = 1 if self.one_proj else len(attr_embs)
        self.pool = torch.nn.ModuleList(
            [
                torch.nn.Sequential(
                    torch.nn.Linear(self.emb_dim, self.emb_dim),
                    torch.nn.ReLU(),
                    torch.nn.Linear(self.emb_dim, self.lang_emb),
                )
                for _ in range(self.num_proj)
            ]
        )

    def _device(self):
        """Device of the first parameter, or CPU when the module has none."""
        param = next(self.parameters(), None)
        return param.device if param is not None else torch.device("cpu")

    def encode_image(self, input):
        """Run the frozen image encoder; returns (image_features, raw_features)."""
        image_features, raw_features = self.image_encoder(input)
        return image_features, raw_features

    def forward(self, sample: dict):
        """
        Project every image region with every projection head.

        Parameters:
            sample: Batch dict with an "img" tensor
        Returns:
            dict with "region_proj_embs" (bs * num_regions, num_proj, lang_emb),
            "num_regions", "image_features" and "raw_features"
        """
        out_dict = {}

        # Follow the model's device instead of hard-coding "cuda".
        img_vector = sample["img"].to(device=self._device(), dtype=torch.float32)
        if img_vector.dim() == 5:
            # NOTE(review): assumes a (bs, 1, H, W, C) channels-last input — confirm with the dataset.
            img_vector = img_vector.squeeze(1).permute(0, 3, 1, 2)

        image_features, raw_features = self.encode_image({"image": img_vector})
        # Dim 1 of raw_features is treated as the region axis; the heads act on the last dim.
        bs = raw_features.size(0)
        num_regions = raw_features.size(1)
        raw_features_flatten = raw_features.reshape(bs, num_regions, -1)

        out_img_a = []
        for head in self.pool:
            pool = head(raw_features_flatten)  # (bs, num_regions, lang_emb)
            if self.adapter:
                pool = 0.2 * pool + 0.8 * raw_features_flatten
            out_img_a.append(pool)

        # Concatenate on the feature axis so each row of the view holds the outputs of
        # all heads for ONE region. Concatenating on dim=1 (regions) mixed regions of
        # different heads into the same row whenever num_proj > 1.
        region_proj_embs = torch.cat(out_img_a, dim=-1).view(-1, self.num_proj, self.lang_emb)
        out_dict["region_proj_embs"] = region_proj_embs
        out_dict["num_regions"] = torch.tensor(num_regions)
        out_dict["image_features"] = image_features
        out_dict["raw_features"] = raw_features
        return out_dict


In [None]:
import torch
import numpy as np
import torch
import abc


class BaseLoss(torch.nn.Module):
    """
    Base class for losses that keep a running mean of their values.

    Subclasses call ``update_running_loss`` once per computed loss; the mean is
    exposed as ``mean_running_loss`` (used by the training progress bar).
    """

    def __init__(self):
        super().__init__()

        self.iteration = 0  # number of losses recorded so far
        self.running_loss = 0  # sum of recorded losses
        self.mean_running_loss = 0  # running_loss / iteration

    def forward(self, input):
        """Identity; subclasses override this with the actual loss computation."""
        return input

    def update_running_loss(self, loss):
        """Record one loss value (0-d tensor or plain number) in the running mean."""
        self.iteration += 1
        self.running_loss += loss.item() if isinstance(loss, torch.Tensor) else float(loss)
        self.mean_running_loss = self.running_loss / self.iteration


class Mapper_loss(BaseLoss):
    """
    Contrastive image-attribute loss for the region Mapper model.

    Region embeddings are compared with the attribute text embeddings. For each
    image and attribute, the best-matching region gives the image-attribute
    similarity. For every attribute present in the batch, each image that has it is
    contrasted against the images that do not.

    Parameters:
        temp: Temperature dividing all similarities
        one_proj: True if the model uses one projection head shared by all attributes
        attr_to_embs: Mapping attribute -> embedding vector; its iteration order must
            match the column order of ``sample["labels"]``
    """

    def __init__(self, temp: float, one_proj: bool, attr_to_embs):
        super().__init__()

        # Initialize class variables
        self.temperature = temp
        self.one_proj = one_proj

        # Attribute embeddings stay on the CPU here and are moved to the prediction's
        # device in forward, so the loss does not require CUDA.
        self.attr_embs = torch.tensor(
            np.stack([attr_to_embs[a] for a in attr_to_embs]), dtype=torch.float32
        )

    def forward(self, pred: dict, sample: dict):
        """
        Parameters:
            pred: Model output with "region_proj_embs" (N, num_proj, d), where N is
                batch_size * num_regions, and "num_regions"
            sample: Batch with 0/1 "labels" of shape (batch_size, num_attributes)
        Returns:
            Scalar loss: mean over positive image-attribute pairs
        """
        anchor_img = torch.nn.functional.normalize(
            pred["region_proj_embs"].float(), dim=2
        )
        device = anchor_img.device
        labels = sample["labels"].to(int)

        # Indexes of the attributes present in at least one image (e.g. mass, calcification)
        attr_ids = labels.sum(0).nonzero().flatten().tolist()
        batch_size = labels.shape[0]

        if not attr_ids:
            # No positive pair in the batch: the loss used to be 0/0 = NaN. Return a
            # zero that stays attached to the graph so backward() still works.
            loss = anchor_img.sum() * 0.0
            self.update_running_loss(loss)
            return loss

        attr_embs = self.attr_embs.to(device)

        # Compute region-attribute similarity matrix by calculating the similarity between
        # each attribute and the region embedding resulting from the corresponding projection head
        if self.one_proj:
            txt_emb = attr_embs[attr_ids, :].T.unsqueeze(0)
            # squeeze(1), not squeeze(): keep the attribute axis when only one is present
            sim = (anchor_img @ txt_emb).squeeze(1) / self.temperature
        else:
            reg_emb = anchor_img[:, attr_ids, :]
            txt_emb = attr_embs[attr_ids, :].unsqueeze(0)
            sim = (reg_emb * txt_emb).sum(2) / self.temperature

        # Convert region-attribute similarity matrix into image-attribute similarity matrix by
        # computing the maximum pairwise similarity between all regions in the image and each attribute
        chunks = torch.split(sim, pred["num_regions"].tolist(), dim=0)
        sim = torch.stack([chunk.max(dim=0).values for chunk in chunks[:batch_size]])  # batch_size x len(attr_ids)
        true_label = labels[:, attr_ids].to(device)
        inv_true_label = (~true_label.bool()).to(int)

        # Per attribute column: exp(positive) / (exp(positive) + sum of exp over images
        # WITHOUT the attribute). Masking after exp keeps positives out of the negative
        # sum; the former exp(sim * mask) added exp(0) = 1 for every positive.
        exp_sim = torch.exp(sim)
        neg_sum = (exp_sim * inv_true_label).sum(0, keepdim=True)
        per_pair = -torch.log(exp_sim / (exp_sim + neg_sum))
        num_loss_terms = true_label.sum()
        loss = (per_pair * true_label).sum() / num_loss_terms
        self.update_running_loss(loss)
        return loss


In [None]:
def _sampler_weights(args, num_samples):
    """Return per-sample WeightedRandomSampler weights, cached under ``args.output_path``."""
    weight_path = args.output_path / f"random_sampler_weights_fold{args.cur_fold}.pkl"
    if weight_path.exists():
        # NOTE: pickle.load can execute arbitrary code; only load cache files this notebook wrote.
        with open(weight_path, "rb") as fh:
            cached = pickle.load(fh)
        # A cache built for a different training set (e.g. another row filter) would not
        # match the dataset length and the sampler would draw invalid indices; recompute.
        if len(cached) == num_samples:
            return np.asarray(cached).tolist()

    fold_weights = args.sampler_weights[f"fold{args.cur_fold}"]
    args.train_folds["weights_random_sampler"] = np.where(
        args.train_folds["cancer"] == 1, fold_weights["pos_wt"], fold_weights["neg_wt"]
    )
    weights = args.train_folds["weights_random_sampler"].values
    with open(weight_path, "wb") as fh:
        pickle.dump(weights, fh)
    return weights.tolist()


def get_dataloader_Vindr(args):
    """
    Build the training and validation DataLoaders for the Mapper model.

    Parameters:
        args: Namespace providing ``train_folds`` / ``valid_folds`` (DataFrames),
            ``batch_size``, ``num_workers`` and ``balanced_dataloader`` ("y" to sample
            with per-class weights). Balanced sampling also needs ``output_path``
            (a Path), ``cur_fold`` and ``sampler_weights``.

    Returns:
        (train_loader, valid_loader)
    """
    train_dataset = MammoDataset_Mapper(args=args, df=args.train_folds, transform=get_transforms(args))
    valid_dataset = MammoDataset_Mapper(args=args, df=args.valid_folds)

    loader_kwargs = dict(batch_size=args.batch_size, num_workers=args.num_workers, pin_memory=True)

    if args.balanced_dataloader == "y":
        weights = _sampler_weights(args, len(train_dataset))
        sampler = WeightedRandomSampler(weights, len(weights), replacement=True)
        train_loader = DataLoader(train_dataset, drop_last=True, sampler=sampler, **loader_kwargs)
    else:
        train_loader = DataLoader(train_dataset, shuffle=True, drop_last=True, **loader_kwargs)

    valid_loader = DataLoader(valid_dataset, shuffle=False, drop_last=False, **loader_kwargs)

    return train_loader, valid_loader

In [None]:
def train_region_mapper(args, device):
    """
    Prepare data, model, loss and optimizer, then train the region Mapper.

    Reads ``args.csv_file`` from ``args.data_dir`` and keeps only rows with a mass or
    a suspicious calcification. It splits them on the ``split`` column and calls
    ``train``. The intermediate frames are stored on ``args`` (``df``,
    ``train_folds``, ``valid_folds``).

    Parameters:
        args: Training configuration namespace (loaded from the config pickle)
        device: Device the model is moved to and trained on
    """
    print(f"=> Training is getting started")
    # Initialize dataloaders
    args.data_dir = Path(args.data_dir)
    args.df = pd.read_csv(args.data_dir / args.csv_file).fillna(0)
    # Keep only images with at least one of the two localized findings.
    args.df = args.df[(args.df["Mass"] == 1) | (args.df["Suspicious_Calcification"] == 1)]
    print(f"df shape: {args.df.shape}")
    print(args.df.columns)
    # NOTE(review): the "test" split doubles as the validation set used for best-model selection.
    args.train_folds = args.df[args.df['split'] == "training"].reset_index(drop=True)
    args.valid_folds = args.df[args.df['split'] == "test"].reset_index(drop=True)
    train_loader, val_loader = get_dataloader_Vindr(args)
    print(f'train_loader: {len(train_loader)}, valid_loader: {len(val_loader)}')

    print(args.clip_chk_pt_path)

    ckpt = torch.load(args.clip_chk_pt_path, map_location="cpu")

    # Initialize model, loss, and optimizer
    # NOTE(review): the embeddings file stores numpy arrays; torch versions whose
    # torch.load defaults to weights_only=True may need weights_only=False here.
    attr_embs = torch.load(args.attr_embs_path)
    # Use the caller-provided device instead of a hard-coded "cuda".
    model = Mapper_model(
        ckpt, lang_emb=args.lang_emb, emb_dim=args.img_emb, one_proj=False, adapter=False, attr_embs=attr_embs
    ).to(device)
    loss = Mapper_loss(temp=0.07, one_proj=False, attr_to_embs=attr_embs)

    opt = torch.optim.AdamW(model.parameters(), lr=args.lr)
    print(
        f"=> Using model {type(model).__name__} with loss {type(loss).__name__} on {torch.cuda.device_count()} GPUs"
    )

    # Train Mapper model
    print(f"=> Training Mapper model")
    train(
        model,
        loss,
        opt,
        None,
        train_loader,
        val_loader,
        args.epochs,
        args.batch_size,
        device,
        args.chk_pt_path,
        False,
    )


def train(
        model,
        loss_fn,
        opt,
        scheduler,
        train_loader,
        val_loader,
        epochs,
        batch_size,
        device,
        chk_pt_path,
        early_stop=False,
        patience=5,
):
    """
    Training loop with mixed precision, per-epoch checkpoints and validation.

    After every epoch the weights are saved to ``<chk_pt_path>/epoch_<n>.pkl``.
    Whenever the validation loss improves, they are also saved to ``best.pkl``.

    Parameters:
        model (torch.nn.Module): Model
        loss_fn (torch.nn.Module): Loss function (must expose ``mean_running_loss``)
        opt (torch.optim): Optimizer
        scheduler (torch.optim.lr_scheduler): LR scheduler stepped with the validation
            loss (set to None if no scheduler)
        train_loader (torch.utils.data.DataLoader): Dataloader for training data
        val_loader (torch.utils.data.DataLoader): Dataloader for validation data
        epochs (int): Number of training epochs
        batch_size (int): Batch size (unused; kept for interface compatibility)
        device: Device passed to ``evaluate`` and queried for CUDA statistics
        chk_pt_path (str): Directory for storing model weights
        early_stop (bool): If True, stop once the validation loss has not improved for
            ``patience`` consecutive epochs
        patience (int): Epochs without improvement tolerated when ``early_stop`` is True
    """
    scaler = torch.cuda.amp.GradScaler()
    epochs_no_improvements = 0
    best_val_loss = float("inf")

    for epoch in range(epochs):
        model.train()
        time_start = time.time()
        progress_iter = tqdm(train_loader, desc=f"[{epoch + 1:03d}/{epochs:03d} epoch train]",
                             total=len(train_loader))
        for sample in progress_iter:
            opt.zero_grad()

            with torch.cuda.amp.autocast(enabled=True):
                pred = model(sample)

            loss = loss_fn(pred, sample)
            scaler.scale(loss).backward()
            scaler.step(opt)
            scaler.update()

            postfix = {
                "lr": [opt.param_groups[0]['lr']],
                "loss": f"{loss_fn.mean_running_loss:.4f}",
            }
            # The CUDA statistics are unavailable on CPU-only machines.
            if torch.cuda.is_available():
                postfix["CUDA-Mem"] = f"{torch.cuda.memory_usage(device)}%"
                postfix["CUDA-Util"] = f"{torch.cuda.utilization(device)}%"
            progress_iter.set_postfix(postfix)

        elapse_time = time.time() - time_start
        print("Finished in {}s".format(int(elapse_time)))

        torch.save(model.state_dict(), os.path.join(chk_pt_path, f"epoch_{epoch + 1}.pkl"))

        val_loss = evaluate(model, epoch, epochs, loss_fn, val_loader, device)
        if val_loss < best_val_loss:
            print("Saving best model")
            torch.save(model.state_dict(), os.path.join(chk_pt_path, "best.pkl"))
            epochs_no_improvements = 0
            best_val_loss = val_loss
        else:
            epochs_no_improvements += 1

        if scheduler is not None:
            scheduler.step(val_loss)

        # Honour the early_stop flag; previously training always stopped after 5
        # epochs without improvement, even when the caller passed early_stop=False.
        if early_stop and epochs_no_improvements >= patience:
            print("Early stop reached")
            return


def evaluate(model, epoch, epochs, loss_fn, val_loader, device, split="val"):
    """
    Validation loop.

    Computes the mean of ``loss_fn(pred, sample)`` over the batches of
    ``val_loader``, with gradients disabled. If ``loss_fn`` has ``iteration``,
    ``running_loss`` and ``mean_running_loss`` attributes, they are restored
    afterwards. This stops validation batches from leaking into the training
    progress readout.

    Parameters:
        model (torch.nn.Module): Model (left in eval mode on return)
        epoch (int): Zero-based index of the current epoch (progress bar only)
        epochs (int): Total number of epochs (progress bar only)
        loss_fn (torch.nn.Module): Loss function
        val_loader (torch.utils.data.DataLoader): Dataloader for validation data
        device: Device queried for CUDA statistics
        split (str): Evaluation split
    Returns:
        loss (torch.Tensor): Mean validation loss over batches
    Raises:
        ValueError: If ``val_loader`` yields no batches.
    """
    print(f"Evaluating on {split}")
    model.eval()

    stat_names = ("iteration", "running_loss", "mean_running_loss")
    saved_stats = {name: getattr(loss_fn, name) for name in stat_names if hasattr(loss_fn, name)}

    running_loss = 0
    num_batches = 0
    try:
        with torch.no_grad():
            progress_iter = tqdm(val_loader, desc=f"[{epoch + 1:03d}/{epochs:03d} epoch {split}]",
                                 total=len(val_loader))
            for sample in progress_iter:
                pred = model(sample)

                running_loss += loss_fn(pred, sample)
                num_batches += 1

                # Show the running mean, not the running sum.
                postfix = {"loss": f"{running_loss / num_batches:.4f}"}
                if torch.cuda.is_available():
                    postfix["CUDA-Mem"] = f"{torch.cuda.memory_usage(device)}%"
                    postfix["CUDA-Util"] = f"{torch.cuda.utilization(device)}%"
                progress_iter.set_postfix(postfix)
    finally:
        for name, value in saved_stats.items():
            setattr(loss_fn, name, value)

    if num_batches == 0:
        raise ValueError(f"Empty {split} loader: cannot compute an evaluation loss")

    loss = running_loss / num_batches
    print(f"Eval Loss = {loss}")
    return loss

### Step 1: Generate the report-sentence or prompt embeddings that localize the findings (mass and suspicious calcification)

#### Replace `data_dir`, `breast_clip_path` and `model` to match your setup (`model` is only a tag used in the embedding file name)

In [None]:
# Fill in the two paths below; `model` is only a tag used in the output file name
# and is reused as `args.arch` in the training cell.
data_dir = "<Path where the sentence/prompt embeddings will be saved>"
breast_clip_path = "<Mammo-CLIP checkpoint path>"
model = "b5_all_clip_mapper"
generate_attribute_embs(
    out_dir=data_dir,
    breast_clip_path=breast_clip_path,
    model=model,
)

### Step 2: Train Mammo-Factor
#### 1. Download the training config `args` (`seed_10_train_configs.pkl`) from [here](https://github.com/batmanlab/Mammo-CLIP/blob/main/src/codebase/notebooks/Mammo-Factor/seed_10_train_configs.pkl)
#### 2. Replace `args.clip_chk_pt_path` and `args.attr_embs_path` with your own paths.
#### `args.attr_embs_path` is the path of the sentence/prompt embedding file generated in Step 1.

In [None]:
# NOTE: pickle.load can execute arbitrary code; only load a config file from a trusted source.
with open("seed_10_train_configs.pkl", "rb") as fh:
    args = pickle.load(fh)
args.root = f"lr_{args.lr}_epochs_{args.epochs}"
args.apex = args.apex == "y"
args.running_interactive = args.running_interactive == "y"

# Change these two paths
args.arch = model  # tag string set in the Step 1 cell
args.clip_chk_pt_path = "<Mammo-CLIP checkpoint path>"
args.attr_embs_path = "<First extract the embeddings of the report sentences and give path to that embedding file>"

# `device` was never defined at notebook level (only locally inside
# generate_attribute_embs), so the prints and the training call below raised NameError.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# NOTE(review): get_Paths is not defined or imported in this notebook; import it from the project codebase.
chk_pt_path, output_path, tb_logs_path = get_Paths(args)
args.chk_pt_path = chk_pt_path
args.output_path = output_path
args.tb_logs_path = tb_logs_path

for path in (chk_pt_path, output_path, tb_logs_path):
    os.makedirs(path, exist_ok=True)
print("====================> Paths <====================")
print(f"checkpoint_path: {chk_pt_path}")
print(f"output_path: {output_path}")
print(f"tb_logs_path: {tb_logs_path}")
print('device:', device)
print('torch version:', torch.__version__)
print("====================> Paths <====================")

# Closing the file guarantees the saved config is flushed to disk before training starts.
with open(os.path.join(output_path, f"seed_{args.seed}_train_configs.pkl"), "wb") as fh:
    pickle.dump(args, fh)
train_region_mapper(args, device)