In [None]:
import torch
import gc
import time
from pathlib import Path

import numpy as np
import pandas as pd
import torch
from sklearn.metrics import f1_score
from torch.optim import Adam
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm
from transformers import get_cosine_schedule_with_warmup
import sys
sys.path.append('/restricted/projectnb/batmanlab/shawn24/PhD/Mammo-CLIP/src/codebase/')

from Classifiers.models.breast_clip_classifier import BreastClipClassifier
from Datasets.dataset_utils import get_dataloader_RSNA
from breastclip.scheduler import LinearWarmupCosineAnnealingLR
from metrics import pfbeta_binarized, pr_auc, compute_auprc, auroc, compute_accuracy_np_array
from utils import seed_all, AverageMeter, timeSince
from breastclip.model.modules import load_image_encoder, LinearClassifier



## Initialize

In [None]:
class Args:
    """Stand-in for the argparse namespace used by the Mammo-CLIP training scripts.

    Holds the paths and hyper-parameters for the RSNA cancer-classification
    tutorial. Later notebook cells attach further attributes to the instance
    (dataframes, folds, encoder type), just as the scripts do with parsed args.

    Args:
        **overrides: optional ``name=value`` pairs that replace the defaults set
            below, e.g. ``Args(batch_size=8, device="cpu")``. Only names that
            already exist as defaults are accepted, so typos fail loudly.

    Raises:
        TypeError: if an override name does not match any default attribute.
    """

    def __init__(self, **overrides):
        # Output / checkpoint locations (absolute, cluster-specific paths;
        # pass overrides to run elsewhere).
        self.tensorboard_path = '/restricted/projectnb/batmanlab/shawn24/PhD/Mammo-CLIP/log'
        self.checkpoints = '/restricted/projectnb/batmanlab/shawn24/PhD/Mammo-CLIP/checkpoints'
        self.output_path = '/restricted/projectnb/batmanlab/shawn24/PhD/Mammo-CLIP/out'
        # Dataset root; img_dir and csv_file are relative to it.
        self.data_dir = '/restricted/projectnb/batmanlab/shared/Data/RSNA_Breast_Imaging/Dataset'
        self.img_dir = 'RSNA_Cancer_Detection/train_images_png'
        # Pretrained Mammo-CLIP checkpoint whose image encoder is reused.
        self.clip_chk_pt_path = '/restricted/projectnb/batmanlab/shawn24/PhD/Breast-CLIP/src/codebase/outputs/upmc_clip/b5_detector_period_n/checkpoints/fold_0/b5-model-best-epoch-7.tar'
        self.csv_file = 'RSNA_Cancer_Detection/train_folds.csv'
        self.dataset = 'RSNA'
        self.data_frac = 1.0
        self.arch = 'upmc_breast_clip_det_b5_period_n_ft'
        self.label = 'cancer'
        self.detector_threshold = 0.1
        self.swin_encoder = 'microsoft/swin-tiny-patch4-window7-224'
        # 'y'/'n' string flags presumably mirror the scripts' CLI convention — confirm.
        self.pretrained_swin_encoder = 'y'
        self.swin_model_type = 'y'
        self.VER = '084'
        # NOTE(review): both `epochs_warmup` (0) and `warmup_epochs` (1) are
        # defined; confirm which one the training / scheduler code actually reads.
        self.epochs_warmup = 0
        self.num_cycles = 0.5
        # Presumably augmentation parameters (elastic alpha/sigma, probability p)
        # consumed by the dataloader helper — confirm.
        self.alpha = 10
        self.sigma = 15
        self.p = 1.0
        # Presumably pixel-intensity normalization statistics — confirm.
        self.mean = 0.3089279
        self.std = 0.25053555408335154
        # Presumably focal-loss parameters — confirm against the loss code.
        self.focal_alpha = 0.6
        self.focal_gamma = 2.0
        self.num_classes = 1
        self.n_folds = 4
        self.start_fold = 0
        self.seed = 10
        self.batch_size = 1
        self.num_workers = 4
        self.epochs = 9
        self.lr = 5.0e-5
        self.weight_decay = 1e-4
        self.warmup_epochs = 1
        # Presumably [height, width] — confirm against the dataset transforms.
        self.img_size = [1520, 912]
        # Resolved once, at construction time.
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.apex = 'y'
        self.print_freq = 5000
        self.log_freq = 1000
        self.running_interactive = 'n'
        self.inference_mode = 'n'
        self.model_type = "Classifier"
        self.weighted_BCE = 'n'
        self.balanced_dataloader = 'n'

        # Apply caller overrides last; reject unknown names so a misspelled
        # option cannot silently create a new, unused attribute.
        unknown = sorted(set(overrides) - set(vars(self)))
        if unknown:
            raise TypeError(f"Args got unexpected override(s): {unknown}")
        for name, value in overrides.items():
            setattr(self, name, value)

# `args` stands in for the command-line namespace the project's scripts receive;
# later cells read from it and attach more attributes to it.
args = Args()

# Last expression of the cell: Jupyter renders it, so no print() is needed.
args.tensorboard_path
# output: '/restricted/projectnb/batmanlab/shawn24/PhD/Mammo-CLIP/log'

## Load data

In [None]:
args.model_base_name = 'efficientnetb5'
args.data_dir = Path(args.data_dir)
args.df = pd.read_csv(args.data_dir / args.csv_file)
# Replace every NaN, in every column, with 0.
args.df = args.df.fillna(0)

# Validation uses `cur_fold`; the training folds are listed explicitly.
args.cur_fold = 0
# NOTE(review): if folds are numbered 0..n_folds-1 (n_folds == 4), fold 3 lands
# in neither split — presumably held out as a test set; confirm before editing.
TRAIN_FOLDS = [1, 2]
# Guard against leakage: the validation fold must never also be a training fold
# (e.g. if cur_fold is changed to 1 or 2 without updating TRAIN_FOLDS).
assert args.cur_fold not in TRAIN_FOLDS, (
    f"cur_fold={args.cur_fold} is also a training fold {TRAIN_FOLDS}")
args.train_folds = args.df[args.df['fold'].isin(TRAIN_FOLDS)].reset_index(drop=True)
args.valid_folds = args.df[args.df['fold'] == args.cur_fold].reset_index(drop=True)

print(f"train_folds shape: {args.train_folds.shape}")
print(f"valid_folds shape: {args.valid_folds.shape}")
# output: train_folds shape: (27258, 15)
# output: valid_folds shape: (13682, 15)

# Pretrained Mammo-CLIP checkpoint, deserialized onto the CPU.
# NOTE(review): torch.load unpickles the file — only load trusted checkpoints.
# On PyTorch versions where weights_only defaults to True, this call may need
# weights_only=False if the stored config holds non-tensor objects — confirm.
ckpt = torch.load(args.clip_chk_pt_path, map_location="cpu")

# Record the encoder *name* from the checkpoint config on `args`
# (presumably read by get_dataloader_RSNA — confirm).
args.image_encoder_type = ckpt["config"]["model"]["image_encoder"]["name"]

train_loader, valid_loader = get_dataloader_RSNA(args)
print('train_loader: {}, valid_loader: {}'.format(len(train_loader), len(valid_loader)))
# output: Compose([
#   HorizontalFlip(p=0.5),
#   VerticalFlip(p=0.5),
#   Affine(p=0.5, interpolation=1, mask_interpolation=0, cval=0.0, mode=0, scale={'x': (0.8, 1.2), 'y': (0.8, 1.2)}, translate_percent={'x': (0.1, 0.1), 'y': (0.1, 0.1)}, translate_px=None, rotate=(20.0, 20.0), fit_output=False, shear={'x': (20.0, 20.0), 'y': (20.0, 20.0)}, cval_mask=0.0, keep_ratio=False, rotate_method='largest_box', balanced_scale=False),
#   ElasticTransform(p=0.5, alpha=10.0, sigma=15.0, interpolation=1, border_mode=4, value=None, mask_value=None, approximate=False, same_dxdy=False),
# ], p=1.0, bbox_params=None, keypoint_params=None, additional_targets={}, is_check_shapes=True)
# None
# train_loader: 3407, valid_loader: 1711

## Load Mammo-CLIP

In [None]:
# Kept in sync with args.num_classes (1 in Args) instead of a separate literal.
n_class = args.num_classes

# Rebuild the image encoder from the checkpoint's own config.
config = ckpt["config"]["model"]["image_encoder"]
print(config)
image_encoder = load_image_encoder(config)

# The checkpoint stores the whole CLIP model; encoder parameters are keyed
# "image_encoder.<name>". Keep only those and strip the prefix so the keys match
# the standalone encoder's state_dict (strict=True raises on any mismatch).
IMAGE_ENCODER_PREFIX = "image_encoder."
image_encoder_weights = {
    key[len(IMAGE_ENCODER_PREFIX):]: value
    for key, value in ckpt["model"].items()
    if key.startswith(IMAGE_ENCODER_PREFIX)
}
image_encoder.load_state_dict(image_encoder_weights, strict=True)
image_encoder_type = config["model_type"]
image_encoder = image_encoder.to(args.device)

print(image_encoder_type)
print(config["name"].lower())
# cnn
# tf_efficientnet_b5_ns-detect

## Loop through the data

In [None]:
# Pass one validation batch through the pretrained encoder to inspect the
# feature shape. eval() makes any BatchNorm layers use (and not update) their
# stored running statistics and disables dropout; without it this forward pass
# would overwrite the pretrained BatchNorm statistics. no_grad() skips building
# an autograd graph for this inspection-only pass.
image_encoder.eval()
progress_iter = tqdm(enumerate(valid_loader), desc="[tutorial]",
                     total=len(valid_loader))
with torch.no_grad():
    for step, data in progress_iter:
        inputs = data['x'].to(args.device)
        # Presumably data['x'] is (B, 1, H, W, C) — TODO confirm; squeeze +
        # permute yields the channels-first (B, C, H, W) layout the encoder takes.
        inputs = inputs.squeeze(1).permute(0, 3, 1, 2)
        batch_size = inputs.size(0)
        image_features = image_encoder(inputs)
        print(image_features.shape)
        break  # a single batch is enough for this tutorial
# The loop exits early, so close the progress bar explicitly.
progress_iter.close()
# torch.Size([1, 2048])