In [1]:
from transformers import BertTokenizer
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import scipy.io
import cv2
import re
from tqdm import tqdm
from collections import defaultdict
import random
import math
from torch.utils.data import Dataset, DataLoader, TensorDataset
from torch.utils.data.sampler import Sampler
import os
from concurrent.futures import ThreadPoolExecutor

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import AdamW
from torch.optim.lr_scheduler import ReduceLROnPlateau

import torchvision
from torchvision import transforms

from transformers import AutoModel

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# Read only the first few columns (e.g., first 3 columns) from a CSV file
df = pd.read_csv('bahar_data_1.csv', low_memory=False)

In [4]:
def remove_empty_reports(df):
    indices = []

    for i, report in enumerate(df['report_page1']):
        if pd.isna(report) or report == "b''" or report == "" or report.strip() == "":
            indices.append(i)

    df.drop(indices, axis=0, inplace=True)
    df.reset_index(drop=True, inplace=True) 
    print(f'no. of empty reports: {len(indices)}, patient_indices: {indices}')

remove_empty_reports(df)

no. of empty reports: 8, patient_indices: [19609, 34167, 55005, 81560, 102207, 122282, 142499, 170474]


In [5]:
def extract_findings_section(df):
    empty_extractions = 0

    extracted_findings_list = []
    bad_indices = []

    for i, report in enumerate(df["report_page1"]):
        match = re.search(r'(.+)(?:Measurements and Calculations:)', report, re.DOTALL | re.IGNORECASE)
        if match:
            findings = match.group(1)
            
            # Remove study info line
            match = re.search(r'Study Info:.*', findings, re.IGNORECASE)
            if match:
                findings = findings[:match.start()] + findings[match.end():]

            # Remove comparison to previous exam line
            match = re.search(r'Comparison to Previous Exam.*|Compared to prior exam.*|Compared to previous study.*', findings, re.IGNORECASE)
            if match:
                findings = findings[:match.start()] + findings[match.end():]

            cleaned_lines = [line.strip() for line in findings.splitlines() if line.strip()]
            
            extracted_findings_list.append('\n'.join(cleaned_lines))
        else:
            match2 = re.search(r'(.+)(?:Comparison to Previous Exam)|(.+)(?:Compared to prior exam)|(.+)(?:Compared to previous study)', report, re.DOTALL | re.IGNORECASE)

            if match2:
                findings = match2.group(1)
                # Remove study info line
                match = re.search(r'Study Info:.*', findings, re.IGNORECASE)
                if match:
                    findings = findings[:match.start()] + findings[match.end():]

                cleaned_lines = [line.strip() for line in findings.splitlines() if line.strip()]
                
                extracted_findings_list.append('\n'.join(cleaned_lines))
            else:
                empty_extractions += 1
                extracted_findings_list.append(-1)
                bad_indices.append(df['report_page1'].index[i])
    
    print(f"Empty extractions: {empty_extractions} Extraction percentage: {empty_extractions/len(df)*100}")
    print(f"Bad reports: {bad_indices}")
    df["extracted_findings"] = extracted_findings_list

extract_findings_section(df)

Empty extractions: 44 Extraction percentage: 0.02382989785639237
Bad reports: [2385, 16181, 19888, 20265, 20872, 21786, 22032, 22825, 24556, 24851, 36989, 51432, 55679, 56322, 57300, 57554, 58374, 58375, 60274, 60596, 79225, 82190, 83027, 99678, 100366, 102916, 103851, 108541, 120014, 122902, 123725, 140015, 140691, 143195, 144117, 148726, 154134, 167319, 171077, 171710, 172675, 172926, 173688, 175562]


In [6]:
def extract_calculations_section(df):
    empty_extractions = 0

    extracted_calcs_list = []
    bad_indices = []

    for i, report in enumerate(df["report_page1"]):
        match = re.search(r'(?:Measurements and Calculations:)(.+?)(?:Sonographer|$)', report, re.DOTALL | re.IGNORECASE)
        
        if match:
            findings = match.group(1)
            cleaned_lines = [line.strip() for line in findings.splitlines() if line.strip()]
            extracted_calcs_list.append('\n'.join(cleaned_lines))
        else:
            match2 = re.search(r'(RVd\s+A4C:.+?)(?:Sonographer|$)', report, re.DOTALL | re.IGNORECASE)

            if match2:
                findings = match2.group(1)
                cleaned_lines = [line.strip() for line in findings.splitlines() if line.strip()]
                extracted_calcs_list.append('\n'.join(cleaned_lines))
            else:
                empty_extractions += 1
                extracted_calcs_list.append(-1)
                bad_indices.append(df['report_page1'].index[i])
    
    print(f"Empty extractions: {empty_extractions} Extraction percentage: {empty_extractions/len(df)*100}")
    print(f"Bad reports: {bad_indices}")
    df["extracted_calcs"] = extracted_calcs_list

extract_calculations_section(df)

Empty extractions: 71 Extraction percentage: 0.03845278972281496
Bad reports: [2385, 8450, 10295, 18240, 19888, 20265, 20872, 21786, 22032, 22825, 24556, 24851, 28956, 36989, 40182, 43473, 45364, 53577, 55679, 56322, 57300, 57554, 58374, 58375, 60274, 60596, 64890, 65524, 72177, 74071, 82190, 83027, 86264, 86675, 91990, 94037, 100366, 102916, 103851, 107422, 107868, 108541, 108788, 113166, 115015, 122902, 123725, 126900, 127295, 132497, 134500, 140691, 143195, 144117, 147632, 148069, 148726, 148972, 154134, 157254, 160289, 161944, 169186, 171077, 171710, 172675, 172926, 173688, 175562, 179771, 180329]


In [7]:
def determine_report_usability(df):
    usable_list = []

    for i, row in df.iterrows():
        if row["extracted_calcs"] != -1 and row["extracted_findings"] != -1:
            usable_list.append(True)
        else:
            usable_list.append(False)

    df["data_usable"] = usable_list

def drop_unusable_reports(df):
    indices =[]

    for i, usable in enumerate(df["data_usable"]):
        if not usable:
            indices.append(i)

    df.drop(indices, axis=0, inplace=True)
    df.reset_index(drop=True, inplace=True) 
    print(f'no. of unusable reports: {len(indices)}, patient_indices: {indices}')
    

determine_report_usability(df)
drop_unusable_reports(df)

no. of unusable reports: 78, patient_indices: [2385, 8450, 10295, 16181, 18240, 19888, 20265, 20872, 21786, 22032, 22825, 24556, 24851, 28956, 36989, 40182, 43473, 45364, 51432, 53577, 55679, 56322, 57300, 57554, 58374, 58375, 60274, 60596, 64890, 65524, 72177, 74071, 79225, 82190, 83027, 86264, 86675, 91990, 94037, 99678, 100366, 102916, 103851, 107422, 107868, 108541, 108788, 113166, 115015, 120014, 122902, 123725, 126900, 127295, 132497, 134500, 140015, 140691, 143195, 144117, 147632, 148069, 148726, 148972, 154134, 157254, 160289, 161944, 167319, 169186, 171077, 171710, 172675, 172926, 173688, 175562, 179771, 180329]


In [None]:
def run_calc_searches(example_text):
    la_mm = re.search(r"left\s+atrium:\s+[*]?(\d+)\s+mm", example_text, flags = re.IGNORECASE)
    ivsd = re.search(r"ivsd:\s+[*]?(\d+)\s+mm", example_text, flags = re.IGNORECASE)
    lvpwd = re.search(r"lvpwd:\s+[*]?(\d+)\s+mm", example_text, flags = re.IGNORECASE)
    la_vol_bpidx = re.search(r"la\s?(?:volume|vol)\s?index\s?(?:\(Biplane\))?:\s+[*]?([\d\.]+)\s+ml/m", example_text, flags = re.IGNORECASE)
    lvidd = re.search(r"lvidd:\s+[*]?(\d+)", example_text, flags = re.IGNORECASE)
    lvidd_idx = re.search(r"lvidd\sindex:\s+[*]?(\d+)\s+mm/m", example_text, flags = re.IGNORECASE)
    rvd_a4c = re.search(r"rvd\sa4c:\s+[*]?(\d+)\s+mm", example_text, flags = re.IGNORECASE)
    lvids = re.search(r"lvids:\s+[*]?(\d+)\s+mm", example_text, flags = re.IGNORECASE)
    rv_s = re.search(r"rv\sS.\s+[*]?([\d\.]+)\s+cm/s", example_text, flags = re.IGNORECASE)
    lvef = re.search(r"lv\s?ef\s?.(?:Biplane|Visual).:\s+[*]?(\d+)\s?%?", example_text, flags = re.IGNORECASE)
    tapse = re.search(r"tapse:\s+[*]?(\d+)\s+mm", example_text, flags = re.IGNORECASE)
    lv_mass_idx = re.search(r"lv\smass\sindex:\s+[*]?([\d\.]+)\s+g/m", example_text, flags = re.IGNORECASE)   
    lv_rwt = re.search(r"lv\s?rwt:\s+[*]?([\d\.]+)", example_text, flags = re.IGNORECASE)
    ra_vol_idx = re.search(r"ra\s?vol[\w|\s]+index:\s+[*]?(\d+)\s+ml/m", example_text, flags = re.IGNORECASE)
    lv_edv_idx = re.search(r"lv\s?edv\s?index:\s+[*]?([\d\.]+)\s+ml/m", example_text, flags = re.IGNORECASE)
    lv_esv_idx = re.search(r"lv\s?esv\s?index:\s+[*]?([\d\.]+)\s+ml/m", example_text, flags = re.IGNORECASE)
    aorta_sinuses = re.search(r"aorta\ssinuses:\s+[*]?(\d+)\s+mm", example_text, flags = re.IGNORECASE)
    lvot_diam = re.search(r"lvot\sdiam:\s+[*]?(\d+)\s+mm", example_text, flags = re.IGNORECASE)
    aorta_sinuses_idx = re.search(r"aorta\ssinuses\sindex:\s+[*]?([\d\.]+)\s+mm/m", example_text, flags = re.IGNORECASE)
    prox_asc_aorta = re.search(r"prox\sascending\saorta:\s+[*]?(\d+)\s+mm", example_text, flags = re.IGNORECASE)
    prox_asc_aorta_idx = re.search(r"prox\sasc\saorta\sindex:\s+[*]?([\d\.]+)\s+mm/m", example_text, flags = re.IGNORECASE)
    mv_peak_e = re.search(r"mv\speak\se:\s+[*]?([\d\.]+)\s+cm/s", example_text, flags = re.IGNORECASE)
    mv_peak_a = re.search(r"mv\speak\sa:\s+[*]?([\d\.]+)\s+cm/s", example_text, flags = re.IGNORECASE)
    mv_ea_ratio = re.search(r"mv\se/a\sratio:\s+[*]?([\d\.]+)", example_text, flags = re.IGNORECASE)
    decel_time = re.search(r"decel\stime:\s+[*]?(\d+)\s+msec", example_text, flags = re.IGNORECASE)
    lateral_e = re.search(r"lateral\se\s?:\s+[*]?([\d\.]+)\s+cm/s", example_text, flags = re.IGNORECASE)
    septal_e = re.search(r"septal\se\s?:\s+[*]?([\d\.]+)\s+cm/s", example_text, flags = re.IGNORECASE)
    avg_ee_ratio = re.search(r"average\se/e\sratio:\s+[*]?([\d\.]+)", example_text, flags = re.IGNORECASE)
    tr_max_velocity = re.search(r"TR\s+max\s+velocity:\s+[*]?([\d\.]+)\s+m/s", example_text, flags = re.IGNORECASE)
    ra_pressure = re.search(r"ra\spressure:?\s+[*]?([\d\.]+)\s+mmHg", example_text, flags = re.IGNORECASE)
    pasp = re.search(r"pasp:\s+[*]?(\d+)\s+mmHg", example_text, flags = re.IGNORECASE)


    measurements = [la_mm, la_vol_bpidx,
                   lvpwd, lvidd, lvidd_idx, lvef, lv_mass_idx, lv_rwt, lv_edv_idx, lv_esv_idx, lvot_diam,
                   ivsd, lvids,
                   rvd_a4c, rv_s,
                    tapse,
                   ra_vol_idx,
                    aorta_sinuses, aorta_sinuses_idx, prox_asc_aorta, prox_asc_aorta_idx,
                   mv_peak_e, mv_peak_a, mv_ea_ratio,
                    decel_time,
                    lateral_e, septal_e,
                    avg_ee_ratio,
                    tr_max_velocity,
                    ra_pressure,
                    pasp]
    
    measurements_ls = [m.group(1) if m !=None else -1 for m in measurements]
    
    return measurements_ls   

print(run_calc_searches(df["extracted_calcs"].iloc[16]))


In [8]:
def crop_and_scale(img, res=(224, 224), interpolation=cv2.INTER_CUBIC, zoom=0.1):
    in_res = (img.shape[1], img.shape[0])
    r_in = in_res[0] / in_res[1]
    r_out = res[0] / res[1]

    if r_in > r_out:
        padding = int(round((in_res[0] - r_out * in_res[1]) / 2))
        if padding > 0:
            img = img[:, padding:-padding]
    if r_in < r_out:
        padding = int(round((in_res[1] - in_res[0] / r_out) / 2))
        if padding > 0:
            img = img[padding:-padding]
    if zoom != 0:
        pad_x = round(int(img.shape[1] * zoom))
        pad_y = round(int(img.shape[0] * zoom))
        if pad_y * 2 < img.shape[0] and pad_x * 2 < img.shape[1]:
            img = img[pad_y:img.shape[0]-pad_y, pad_x:img.shape[1]-pad_x]

    img = cv2.resize(img, res, interpolation=interpolation)
    return img

In [9]:
val_percent = 1/400
test_percent = 1/100
train_percent = 1 - val_percent - test_percent

grouped = df.groupby("patient_id").first().reset_index()
patients = grouped["patient_id"].sample(frac=1, random_state=42).reset_index(drop=True)

train_index = int(train_percent * len(patients))
val_index = int(val_percent * len(patients))

train_patients = patients[:train_index]
val_patients = patients[train_index:train_index + val_index]
test_patients = patients[train_index + val_index:]

print(f"Train Patients: {len(train_patients)}, Val Patients: {len(val_patients)}, Test Patients: {len(test_patients)}")

df_train = df[df["patient_id"].isin(train_patients)]
df_val = df[df["patient_id"].isin(val_patients)]
df_test = df[df["patient_id"].isin(test_patients)]

model_name = "microsoft/BiomedNLP-BiomedBERT-base-uncased-abstract"
#model_name = "yikuan8/Clinical-Longformer"
tokenizer = BertTokenizer.from_pretrained(model_name)

Train Patients: 34363, Val Patients: 86, Test Patients: 349


In [10]:
def get_report_vid_pairs(df):

    # Group by exam_id and get the first report per exam
    grouped = df.groupby("exam_id")["extracted_findings"].first().reset_index()

    batch_size = 32
    report_dict = {}

    exam_ids = grouped["exam_id"].tolist()
    reports = grouped["extracted_findings"].tolist()

    for i in tqdm(range(0, len(reports), batch_size)):
        batch_reports = reports[i:i+batch_size]
        batch_exam_ids = exam_ids[i:i+batch_size]

        tokens = tokenizer(batch_reports, padding=False, truncation=False, add_special_tokens=False)

        for j, input_ids in enumerate(tokens["input_ids"]):
            if len(input_ids) <= 512:
                report_dict[batch_exam_ids[j]] = batch_reports[j]


    # Build vid_dict using filtered exam_ids
    vid_dict = {}
    for exam in tqdm(report_dict.keys()):
        vid_paths = df[df["exam_id"] == exam]["processed_file_address"].tolist()
        vid_dict[exam] = vid_paths

    bad_vids_count = 0
    good_vids_count = 0
    min_number_batches = 0
    # Group video paths by report
    report_to_vids = defaultdict(list)
    for exam, vids in tqdm(vid_dict.items()):

        good_vids = []
        for vid in vids:
            vid_path = f'/data{vid[15:]}'

            if os.path.exists(vid_path):
                good_vids.append(vid_path)
                good_vids_count += 1
            else:
                bad_vids_count += 1

        if good_vids:
            report = report_dict[exam]
            report_to_vids[report].extend(good_vids)
            
        min_number_batches = max(min_number_batches, len(good_vids))

    num_batches = max(math.ceil(good_vids_count/batch_size), min_number_batches)

    batches = [[] for _ in range(num_batches)]
    batch_lengths = [0] * num_batches

    i = 0
    for report, vids in tqdm(report_to_vids.items()):
        for vid in vids:
            attempts = 0
            found_batch = False
            while attempts < num_batches:
                batch_idx = i % num_batches
                if batch_lengths[batch_idx] < batch_size:
                    batches[batch_idx].append((report, vid))
                    batch_lengths[batch_idx] += 1
                    found_batch = True
                    break
                i += 1
                attempts += 1
            if not found_batch:
                raise ValueError("Unable to find batch for video")
            i += 1  # Move to next batch for next video of same report

    batch_sizes = [len(batch) for batch in batches]

    print(batch_sizes)
    print(f"Number of missing videos: {bad_vids_count}")

    # Flatten batches into final list
    report_vid_pairs = [pair for batch in batches for pair in batch]

    return report_vid_pairs, batch_sizes

In [11]:
def process_path(path):
    frames_to_take = 32
    frame_stride = 2
    video_size = 224

    mean = torch.tensor([29.110628, 28.076836, 29.096405]).reshape(3, 1, 1, 1)
    std = torch.tensor([47.989223, 46.456997, 47.20083]).reshape(3, 1, 1, 1)

    mat_data = scipy.io.loadmat(path)
    volume = np.array(mat_data['cropped'])
    volume = crop_and_scale(volume)
    volume = np.repeat(volume[..., None], 3, axis=3)
    volume = np.transpose(volume, (3, 2, 0, 1))
    x = torch.as_tensor(volume, dtype=torch.float)
    x.sub_(mean).div_(std)

    if x.shape[1] < frames_to_take:
        padding = torch.zeros((3, frames_to_take - x.shape[1], video_size, video_size), dtype=torch.float)
        x = torch.cat((x, padding), dim=1)

    return x[:, :frames_to_take:frame_stride, :, :]


def batch_generator(report_vid_pairs, batch_sizes):
    idx = 0
    for batch_size in batch_sizes:
        batch = report_vid_pairs[idx:idx + batch_size]
        idx += batch_size

        # Process this specific batch
        reports = [r for r, _ in batch]
        paths = [v for _, v in batch]

        padded = tokenizer(reports, return_tensors='pt', padding='max_length', truncation=True, max_length=512)
        input_ids = padded["input_ids"]
        attention_mask = padded["attention_mask"]
        token_type_ids = padded.get("token_type_ids")

        with ThreadPoolExecutor() as executor:
            videos = list(executor.map(process_path, paths))
        video_tensor = torch.stack(videos)

        if token_type_ids is not None:
            yield video_tensor, input_ids, attention_mask, token_type_ids
        else:
            yield video_tensor, input_ids, attention_mask

In [12]:

test_report_vid_pairs, test_batch_sizes = get_report_vid_pairs(df_test)

val_report_vid_pairs, val_batch_sizes = get_report_vid_pairs(df_val)

train_report_vid_pairs, train_batch_sizes = get_report_vid_pairs(df_train)

  0%|          | 0/13 [00:00<?, ?it/s]

100%|██████████| 13/13 [00:04<00:00,  2.97it/s]
100%|██████████| 394/394 [00:00<00:00, 1630.83it/s]
100%|██████████| 394/394 [00:04<00:00, 81.08it/s] 
100%|██████████| 367/367 [00:00<00:00, 191408.80it/s]


[32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31]
Number of missing videos: 189


100%|██████████| 3/3 [00:01<00:00,  2.79it/s]
100%|██████████| 94/94 [00:00<00:00, 1655.21it/s]
100%|██████████| 94/94 [00:00<00:00, 106.96it/s]
100%|██████████| 87/87 [00:00<00:00, 140618.28it/s]


[32, 32, 32, 32, 32, 32, 32, 31, 31, 31, 31, 31]
Number of missing videos: 64


100%|██████████| 1192/1192 [06:47<00:00,  2.92it/s]
100%|██████████| 37889/37889 [00:26<00:00, 1435.74it/s]
100%|██████████| 37889/37889 [05:12<00:00, 121.32it/s]
100%|██████████| 35301/35301 [00:00<00:00, 168861.26it/s]

[32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32,




In [14]:
batch_size = 32
lr = 4e-5
weight_decay = 1e-6
num_epochs = 60
save_path = "/checkpoints/best_model.pt"
data_path = os.path.expanduser("~/model_data/weights")
cuda_device = 1

DEVICE = torch.device(f'cuda:{cuda_device}' if torch.cuda.is_available() else 'cpu')
EMBED_DIM = 512
TEXT_MODEL_FROZEN_LAYERS = 6
TEXT_MODEL_NAME = "microsoft/BiomedNLP-BiomedBERT-base-uncased-abstract"

# === Load Echo Encoder ===
checkpoint = torch.load(os.path.join(data_path, "echo_prime_encoder.pt"), map_location=DEVICE)
echo_encoder = torchvision.models.video.mvit_v2_s()
echo_encoder.head[-1] = nn.Linear(echo_encoder.head[-1].in_features, EMBED_DIM)
echo_encoder.load_state_dict(checkpoint)
echo_encoder.eval()
echo_encoder.to(DEVICE)
for param in echo_encoder.parameters():
    param.requires_grad = False

text_encoder_full = AutoModel.from_pretrained(TEXT_MODEL_NAME)
for layer in text_encoder_full.encoder.layer[:TEXT_MODEL_FROZEN_LAYERS]:
    for param in layer.parameters():
        param.requires_grad = False

class TextEncoder(nn.Module):
    def __init__(self, base_encoder, embed_dim):
        super().__init__()
        self.base_encoder = base_encoder
        self.projector = nn.Linear(base_encoder.config.hidden_size, embed_dim)

    def forward(self, input_ids, attention_mask=None, token_type_ids=None):
        outputs = self.base_encoder(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)
        pooled = outputs.last_hidden_state[:, 0, :]
        return self.projector(pooled)

text_encoder = TextEncoder(text_encoder_full, EMBED_DIM).to(DEVICE)

In [16]:
 # === Loss & Optimizer ===
criterion = nn.CrossEntropyLoss()
optimizer = AdamW(text_encoder.parameters(), lr=lr, weight_decay=weight_decay)
scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=2, verbose=True)

# === Training Loop ===
best_val_loss = float('inf')

for epoch in range(num_epochs):
    text_encoder.train()
    total_train_loss = 0

    for batch in batch_generator(train_report_vid_pairs, train_batch_sizes):
        if len(batch) == 4:
            video_tensor, input_ids, attention_mask, token_type_ids = batch
        else:
            video_tensor, input_ids, attention_mask = batch
            token_type_ids = None

        video_tensor = video_tensor.to(DEVICE)
        input_ids = input_ids.to(DEVICE)
        attention_mask = attention_mask.to(DEVICE)

        # if token_type_ids:
        #     token_type_ids.to(DEVICE)

        with torch.no_grad():
            video_embeds = echo_encoder(video_tensor)
            video_embeds = F.normalize(video_embeds, dim=1)

        text_embeds = text_encoder(input_ids, attention_mask)
        text_embeds = F.normalize(text_embeds, dim=1)

        sim_matrix = torch.matmul(text_embeds, video_embeds.T)
        target = torch.arange(batch_size).to(DEVICE)
        loss = criterion(sim_matrix, target)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_train_loss += loss.item()

    avg_train_loss = total_train_loss / 100

    # === Validation ===
    text_encoder.eval()
    total_val_loss = 0

    with torch.no_grad():
        for batch in batch_generator(val_report_vid_pairs, val_batch_sizes):
            if len(batch) == 4:
                video_tensor, input_ids, attention_mask, token_type_ids = batch
            else:
                video_tensor, input_ids, attention_mask = batch
                token_type_ids = None
            
            video_tensor = video_tensor.to(DEVICE)
            input_ids = input_ids.to(DEVICE)
            attention_mask = attention_mask.to(DEVICE)

            video_embeds = echo_encoder(video_tensor)
            video_embeds = F.normalize(video_embeds, dim=1)

            text_embeds = text_encoder(input_ids, attention_mask)
            text_embeds = F.normalize(text_embeds, dim=1)

            sim_matrix = torch.matmul(text_embeds, video_embeds.T)
            target = torch.arange(batch_size).to(DEVICE)
            loss = criterion(sim_matrix, target)
            total_val_loss += loss.item()

    avg_val_loss = total_val_loss / 20
    scheduler.step(avg_val_loss)

    print(f"Epoch {epoch + 1} | Train Loss: {avg_train_loss:.4f} | Val Loss: {avg_val_loss:.4f}")

    if avg_val_loss < best_val_loss:
        best_val_loss = avg_val_loss
        torch.save(text_encoder.state_dict(), save_path)
        print(f"✅ Saved new best model at epoch {epoch + 1} with val loss {best_val_loss:.4f}")


OutOfMemoryError: CUDA out of memory. Tried to allocate 2.35 GiB. GPU 1 has a total capacity of 15.77 GiB of which 717.12 MiB is free. Process 2089 has 15.07 GiB memory in use. Of the allocated memory 7.42 GiB is allocated by PyTorch, and 6.50 GiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)