In [None]:
from transformers import CLIPProcessor, CLIPModel
import torch
from torch.utils.data import DataLoader
from torchvision.datasets import CIFAR10
from torchvision.transforms import Compose, Resize, Normalize, ToTensor
from PIL import Image
import os
import os
import glob
import pydicom
import pandas as pd
from transformers import CLIPModel, CLIPTokenizer
import torch
from PIL import Image
from torch.utils.data import Dataset

import torch.nn.functional as F
from torch import nn
from transformers import CLIPProcessor, CLIPModel
import text_aug as TA
import matplotlib.pyplot as plt
import numpy as np
# Fix the global PyTorch and NumPy RNG seeds to 0.
torch.manual_seed(seed=0)
np.random.seed(seed=0)

In [None]:
# Read the 2D metadata table; later cells use its image_dir, region,
# image_type and Additional_Info columns.
dataCSV = pd.read_csv(filepath_or_buffer='../Data/2D/2d.csv')

In [None]:
# Shuffle every row (frac=1 draws the whole frame without replacement),
# then renumber the index from 0.
dataCSV = (
    dataCSV
    .sample(frac=1)
    .reset_index(drop=True)
)

In [None]:
#start with small dataset
# dataCSV=dataCSV[:100]

In [None]:
# Replace every missing value with the string "0"; the caption-building cell
# treats an Additional_Info of "0" as "nothing to append".
dataCSV = dataCSV.fillna(value="0")

In [None]:
# Collapse the spelling variants of "X-Ray" into a single label, then remove
# all spaces from every image_type value.
dataCSV['image_type'] = (
    dataCSV['image_type']
    .replace(["x-ray", "X-ray", "xRay", "Xray"], "X-Ray")
    .str.replace(' ', '')
)

In [None]:
# Pull the columns used downstream out of the frame as plain Python lists.
image_paths, regions, image_types, addition_infos = (
    dataCSV[column_name].tolist()
    for column_name in ('image_dir', 'region', 'image_type', 'Additional_Info')
)
# Row labels of the shuffled frame (0..n-1 after reset_index).
image_index = dataCSV.index.tolist()

In [None]:
# Build one caption per row: "<region> <image_type>", followed by
# ", <Additional_Info>" when that column holds something other than "0".
texts = []
for region, image_type, addition_info in zip(regions, image_types, addition_infos):
    text = region + " " + image_type
    if addition_info != "0":
        text += ", " + addition_info
    texts.append(text)

dataCSV['text'] = texts


In [None]:
#print(set(dataCSV['text']))

In [None]:
# Load the pretrained CLIP checkpoint together with its processor and
# tokenizer, and move the model to the GPU when one is available.
model_name = "openai/clip-vit-large-patch14"
device = "cuda" if torch.cuda.is_available() else "cpu"
processor = CLIPProcessor.from_pretrained(model_name)
tokenizer = CLIPTokenizer.from_pretrained(model_name)
model = CLIPModel.from_pretrained(model_name).to(device)

In [None]:
# Freeze the text side of CLIP (projection head and encoder) so that no
# gradients are computed for it.
for frozen_part in (model.text_projection, model.text_model):
    for param in frozen_part.parameters():
        param.requires_grad = False




In [None]:
def process_image_path_to_pixels(file_path):
    """Load an image file from disk as a 3-channel PIL image.

    JPEG and PNG files are opened with PIL and converted to RGB. DICOM files
    (``.dicom`` or ``.dcm``) are read with pydicom; pixel data that is not
    already ``uint8`` is min-max rescaled to 0..255, and single-channel data
    is replicated across three channels.

    Args:
        file_path: Path of the image. The format is chosen from the file
            extension, case-insensitively.

    Returns:
        PIL.Image.Image: The decoded image.

    Raises:
        ValueError: If the extension is not one of the supported formats.
    """
    _, ext = os.path.splitext(file_path)
    ext = ext.lower()

    if ext in ('.jpg', '.jpeg', '.png'):
        # convert() returns a fully loaded copy, so the file handle can be
        # released as soon as the conversion is done.
        with Image.open(file_path) as raw_image:
            return raw_image.convert('RGB')

    if ext in ('.dicom', '.dcm'):
        image_array = np.asarray(pydicom.dcmread(file_path).pixel_array)
        if image_array.dtype != np.uint8:
            # pixel_array may be wider than 8 bits (e.g. 16-bit), and
            # Image.fromarray cannot build an RGB image from a non-uint8
            # (H, W, 3) array, so rescale to the 0..255 range first.
            scaled = image_array.astype(np.float32)
            low = scaled.min()
            span = scaled.max() - low
            if span > 0:
                scaled = (scaled - low) / span * 255.0
            else:
                # Constant image: avoid dividing by zero.
                scaled = np.zeros_like(scaled)
            image_array = scaled.astype(np.uint8)
        if image_array.ndim == 2:  # Single-channel -> replicate to 3 channels
            image_array = np.stack((image_array,) * 3, axis=-1)
        # NOTE(review): arrays that are neither 2-D nor (H, W, 3) (e.g.
        # multi-frame series) are passed through unchanged, as before --
        # confirm whether the dataset contains any.
        return Image.fromarray(image_array)

    raise ValueError(f"Unsupported file format: {ext!r} ({file_path})")


In [None]:
from transformers import CLIPProcessor

class CLIPDataset(Dataset):
    """Dataset of (image, caption) pairs preprocessed for CLIP.

    Args:
        image_paths: Sequence of image file paths.
        texts: Sequence of captions, one per entry of ``image_paths``.
        processor: Optional callable used to preprocess each pair. When
            ``None`` a ``CLIPProcessor`` is loaded from ``model_name``.
        model_name: Checkpoint name used to load the default processor.

    Raises:
        ValueError: If ``image_paths`` and ``texts`` differ in length.
    """

    def __init__(self, image_paths, texts, processor=None,
                 model_name="openai/clip-vit-large-patch14"):
        if len(image_paths) != len(texts):
            raise ValueError(
                f"image_paths and texts must have the same length, "
                f"got {len(image_paths)} and {len(texts)}"
            )
        self.image_paths = image_paths
        self.texts = texts
        if processor is None:
            processor = CLIPProcessor.from_pretrained(model_name)
        self.processor = processor
        # Paths that could not be read and were replaced by the placeholder
        # pair. Only complete when loading happens in the main process
        # (DataLoader num_workers=0); worker processes get their own copy.
        self.failed_paths = set()

    def __len__(self):
        """Return the number of (image, caption) pairs."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Return the processor output for pair ``idx`` without the batch axis.

        If the image cannot be read (``OSError``), the pair is replaced by an
        all-zero 3x224x224 tensor with the caption ``'dummy'`` and the path is
        added to ``self.failed_paths``.
        """
        img_path = self.image_paths[idx]
        text = self.texts[idx]

        try:
            image = process_image_path_to_pixels(img_path)
        except OSError:
            # IOError and FileNotFoundError are both OSError in Python 3.
            # Best effort: keep the run going with a placeholder pair, but
            # remember which file failed so it can be inspected afterwards.
            self.failed_paths.add(img_path)
            text = 'dummy'
            image = torch.zeros((3, 224, 224))

        # truncation=True caps long captions at the tokenizer's maximum
        # length; with padding="max_length" alone an over-long caption would
        # produce a longer row than the others and break batching.
        processed = self.processor(
            text=text,
            images=image,
            return_tensors="pt",
            padding="max_length",
            truncation=True,
        )
        # Drop the leading batch axis of size 1 so the DataLoader can stack
        # the samples itself.
        for key in ('input_ids', 'attention_mask', 'pixel_values'):
            processed[key] = processed[key].squeeze(0)
        return processed


In [None]:
# Pair every image path with its caption.
dataset = CLIPDataset(image_paths=image_paths, texts=texts)

In [None]:
# Batches of 54 pairs in fixed dataset order, loaded in the main process.
dataloader = DataLoader(
    dataset=dataset,
    batch_size=54,
    shuffle=False,
    num_workers=0,
)

In [None]:

# Number of passes over the data, and an AdamW optimizer over all model
# parameters with learning rate 1e-5.
epochs = 10
optimizer = torch.optim.AdamW(params=model.parameters(), lr=1e-5)


In [None]:
def infonce_loss(img_embed, text_embed, temperature=0.3):
    """One-directional (image -> text) InfoNCE loss for a batch of pairs.

    Row ``i`` of ``img_embed`` is matched against every row of
    ``text_embed``; the correct class for row ``i`` is column ``i``. The
    embeddings are used as given -- no L2 normalisation is applied here.

    Args:
        img_embed: 2-D tensor of image embeddings, one row per sample.
        text_embed: 2-D tensor of text embeddings, one row per sample.
        temperature: Divisor applied to the similarity matrix.

    Returns:
        Scalar tensor with the mean cross-entropy over the batch.
    """
    scaled_similarity = torch.mm(img_embed, text_embed.t()).div(temperature)
    # The matching caption of image i sits on the diagonal.
    targets = torch.arange(img_embed.size(0), device=img_embed.device)
    return F.cross_entropy(scaled_similarity, targets)


In [None]:
# Wrap the model in DataParallel on up to two GPUs (devices 0 and 1).
# - The device list is derived from the GPUs actually present, so the cell
#   also runs on a single-GPU machine instead of failing on device id 1.
# - The isinstance guard keeps the cell safe to re-run: wrapping an already
#   wrapped model would nest DataParallel and make `model.module` (used when
#   saving checkpoints) return another wrapper.
if not isinstance(model, nn.DataParallel):
    gpu_ids = list(range(min(2, torch.cuda.device_count())))
    model = nn.DataParallel(model, device_ids=gpu_ids)
model = model.to(device)

In [None]:
# Report, for every parameter of the model, whether it will be trained.
for name, param in model.named_parameters():
    state = "unfrozen" if param.requires_grad else "frozen"
    print(f"{name} is {state}")

In [None]:
# Margin hyper-parameter; not referenced by any other cell visible in this
# part of the notebook.
margin = 0.5

In [None]:

def contrastive_loss(image_features, text_features, temperature=0.07):
    """Symmetric contrastive loss between paired image and text features.

    Both inputs are L2-normalised along the last dimension, so the logits
    are cosine similarities divided by ``temperature``. Pair ``i`` of the
    batch is the positive for row ``i`` in both directions.

    Args:
        image_features: Tensor of image features, one row per sample.
        text_features: Tensor of text features, one row per sample.
        temperature: Divisor applied to the cosine-similarity matrix.

    Returns:
        Scalar tensor: the mean of the image->text and text->image
        cross-entropy losses.
    """
    img_unit = F.normalize(image_features, p=2, dim=-1)
    txt_unit = F.normalize(text_features, p=2, dim=-1)

    logits = (img_unit @ txt_unit.T) / temperature

    # Positives lie on the diagonal of the similarity matrix.
    targets = torch.arange(len(logits), device=logits.device)

    image_to_text = F.cross_entropy(logits, targets)
    text_to_image = F.cross_entropy(logits.T, targets)
    return (image_to_text + text_to_image) / 2


In [None]:

# Fine-tune CLIP with the symmetric contrastive loss.
# A temporary checkpoint is written every 100 optimizer steps and a per-epoch
# checkpoint at the end of every epoch. Checkpoints always hold the state dict
# of the underlying model, whether or not it is wrapped in DataParallel.
for epoch in range(epochs):
    train_loss = 0
    step = 0
    model.train()
    for batch in dataloader:
        if batch is None:
            continue

        optimizer.zero_grad(set_to_none=True)

        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)
        pixel_values = batch["pixel_values"].to(device)

        outputs = model(pixel_values=pixel_values, attention_mask=attention_mask, input_ids=input_ids)

        loss = contrastive_loss(outputs.image_embeds, outputs.text_embeds)
        # Read the scalar once; it is reused for the running total and the log.
        batch_loss = loss.item()
        train_loss += batch_loss

        # Backward pass and optimizer step
        loss.backward()
        optimizer.step()
        if step % 1000 == 0:
            print(f"Epoch [{epoch+1}/{epochs}], Step [{step}/{len(dataloader)}], Loss: {batch_loss}")
        step += 1
        # Periodic temporary checkpoint
        if step % 100 == 0:
            # Fall back to the model itself when it is not wrapped in DataParallel.
            original_model = getattr(model, "module", model)
            torch.save(original_model.state_dict(), "clip_model_temp.pth")
    print(f"Epoch [{epoch+1}/{epochs}], Loss: {train_loss/len(dataloader)}")
    # End-of-epoch checkpoint
    original_model = getattr(model, "module", model)
    torch.save(original_model.state_dict(), "clip_model_epoch_"+str(epoch)+".pth")
