<a href="https://colab.research.google.com/github/morzahavi/corsound/blob/main/Untitled17.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [105]:
# gpu_info = !nvidia-smi
# gpu_info = '\n'.join(gpu_info)
# if gpu_info.find('failed') >= 0:
#   print('Not connected to a GPU')
# else:
#   print(gpu_info)

In [106]:
# from psutil import virtual_memory
# ram_gb = virtual_memory().total / 1e9
# print('Your runtime has {:.1f} gigabytes of available RAM\n'.format(ram_gb))

# if ram_gb < 20:
#   print('Not using a high-RAM runtime')
# else:
#   print('You are using a high-RAM runtime!')

In [107]:
import os

# Check if the extracted directory or file exists (replace 'dataset_directory_or_file_name' with the actual name)
if not os.path.exists("dataset"):
    # Download the dataset
    !gdown --id 1VSexRwiUJmcCyXw3KSEzFGxNhorngHq7

    # Extract the dataset
    !tar xf dataset_classification.tar

    # Install the required library
    !pip install transformers


In [108]:
import torch
import torch.nn as nn
import torchvision.models as models
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Dataset
from transformers import BertTokenizer, BertModel
import pandas as pd
from PIL import Image
from sklearn.model_selection import train_test_split
from sklearn.metrics import confusion_matrix, accuracy_score, precision_score, recall_score, f1_score
import numpy as np
from tqdm import tqdm
from torch.optim.lr_scheduler import OneCycleLR
import datetime
import pytz

# Wandb import
try:
    import wandb
except ImportError:
    !pip install wandb

# Constants
IMG_SIZE = (224, 224)
NORMALIZE_MEAN = (0.485, 0.456, 0.406)
NORMALIZE_STD = (0.229, 0.224, 0.225)
TIMEZONE = pytz.timezone('Asia/Jerusalem')
CURRENT_TIME = datetime.datetime.now(TIMEZONE).strftime('%Y_%m_%d__%H_%M_%S')


class CFG:
    debug = True
    subset = 1000
    comment = ""
    seed = 101
    backbone = "resnet101"
    batch_size = 8
    epochs = 5
    loss = "binary_crossentropy"
    optimizer = "Adam"
    lr = 1e-4
    token_length = 100


class CustomDataset(Dataset):
    def __init__(self, dataframe, tokenizer, max_length):
        self.dataframe = dataframe
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.transform = transforms.Compose([
            transforms.Resize(IMG_SIZE),
            transforms.ToTensor(),
            transforms.Normalize(NORMALIZE_MEAN, NORMALIZE_STD)
        ])

    def __len__(self):
        return len(self.dataframe)

    def __getitem__(self, idx):
        row = self.dataframe.iloc[idx]
        image = Image.open(row["image"]).convert("RGB")
        image = self.transform(image)
        text = row["text"]
        inputs = self.tokenizer(text, return_tensors="pt", max_length=self.max_length, padding="max_length", truncation=True)
        input_ids = inputs["input_ids"].squeeze(0)
        attention_mask = inputs["attention_mask"].squeeze(0)
        label = torch.tensor(row["label"], dtype=torch.float32)
        return image, input_ids, attention_mask, label

class MultimodalClassifier(nn.Module):
    def __init__(self):
        super(MultimodalClassifier, self).__init__()
        self.resnet = getattr(models, CFG.backbone)(pretrained=True)

        # Get the in_features before replacing the fc layer with Identity
        resnet_out_features = self.resnet.fc.in_features

        self.resnet.fc = nn.Identity()

        self.bert = BertModel.from_pretrained("bert-base-uncased")
        self.fc = nn.Linear(resnet_out_features + self.bert.config.hidden_size, 1)

    def forward(self, image, input_ids, attention_mask):
        image_embed = self.resnet(image)
        text_embed = self.bert(input_ids=input_ids, attention_mask=attention_mask).pooler_output
        combined = torch.cat([image_embed, text_embed], dim=1)
        output = self.fc(combined)
        return output.squeeze()

# Class Weights
class_counts = train_data["label"].value_counts().to_dict()

# For binary classification:
num_negative = class_counts[0]
num_positive = class_counts[1]

total = num_negative + num_positive

weight_for_0 = (1 / num_negative) * (total) / 2.0
weight_for_1 = (1 / num_positive) * (total) / 2.0

class_weights = {0: weight_for_0, 1: weight_for_1}
print(class_weights)

#
def load_data():
    data_path = "dataset/dataset.parquet"
    data = pd.read_parquet(data_path)
    if CFG.debug:
        data = data.sample(CFG.subset)
    return train_test_split(data, test_size=0.2, random_state=CFG.seed)


def evaluate(model, dataloader):
    model.eval()
    all_labels, all_preds = [], []
    with torch.no_grad():
        for items in tqdm(dataloader, desc="Evaluating"):
            image, input_ids, attention_mask, label = (item.to(device) for item in items)
            output = model(image, input_ids, attention_mask)
            preds = (torch.sigmoid(output) > 0.5).int().cpu().numpy()  # Applying sigmoid before thresholding
            all_preds.extend(preds)
            all_labels.extend(label.cpu().int().numpy())
    return all_labels, all_preds


if __name__ == "__main__":
    cfg_instance = CFG()
    wandb.init(project='caption_prediction', name=CURRENT_TIME, config=vars(cfg_instance))
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    train_data, val_data = load_data()
    tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
    train_dataset = CustomDataset(train_data, tokenizer, CFG.token_length)
    val_dataset = CustomDataset(val_data, tokenizer, CFG.token_length)

    train_dataloader = DataLoader(train_dataset, batch_size=CFG.batch_size, shuffle=True, num_workers=4, pin_memory=True)
    val_dataloader = DataLoader(val_dataset, batch_size=CFG.batch_size, shuffle=False, num_workers=4, pin_memory=True)

    model = MultimodalClassifier().to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=CFG.lr)
    weights = torch.tensor([class_weights[0], class_weights[1]]).to(device)
    loss_fn = nn.BCEWithLogitsLoss(pos_weight=weights[1])
    scheduler = OneCycleLR(optimizer, max_lr=1e-3, epochs=CFG.epochs, steps_per_epoch=len(train_dataloader))
    wandb.watch(model, log='all')

    for epoch in range(CFG.epochs):
        model.train()
        total_loss = 0
        for items in tqdm(train_dataloader, desc=f"Epoch {epoch+1}/{CFG.epochs}"):
            image, input_ids, attention_mask, label = (item.to(device) for item in items)
            optimizer.zero_grad()
            output = model(image, input_ids, attention_mask)
            loss = loss_fn(output, label)
            loss.backward()
            optimizer.step()
            scheduler.step()
            total_loss += loss.item()
        wandb.log({"Train Loss": total_loss/len(train_dataloader)}, step=epoch)
        labels, preds = evaluate(model, val_dataloader)
        metrics = {
            "Accuracy": accuracy_score(labels, preds),
            "Precision": precision_score(labels, preds),
            "Recall": recall_score(labels, preds),
            "F1 Score": f1_score(labels, preds),
        }
        print(metrics)
        wandb.log(metrics, step=epoch)

    torch.save(model.state_dict(), 'model.pth')
    wandb.save('model.pth')
    wandb.finish()





{0: 0.5847953216374269, 1: 3.4482758620689653}


Epoch 1/5: 100%|██████████| 100/100 [00:32<00:00,  3.06it/s]
Evaluating: 100%|██████████| 25/25 [00:02<00:00,  8.96it/s]


{'Accuracy': 0.195, 'Precision': 0.18781725888324874, 'Recall': 0.9736842105263158, 'F1 Score': 0.31489361702127666}


Epoch 2/5: 100%|██████████| 100/100 [00:33<00:00,  2.95it/s]
Evaluating: 100%|██████████| 25/25 [00:02<00:00,  8.93it/s]


{'Accuracy': 0.79, 'Precision': 0.0, 'Recall': 0.0, 'F1 Score': 0.0}


Epoch 3/5: 100%|██████████| 100/100 [00:33<00:00,  3.01it/s]
Evaluating: 100%|██████████| 25/25 [00:02<00:00,  9.11it/s]
  _warn_prf(average, modifier, msg_start, len(result))


{'Accuracy': 0.81, 'Precision': 0.0, 'Recall': 0.0, 'F1 Score': 0.0}


Epoch 4/5: 100%|██████████| 100/100 [00:33<00:00,  3.00it/s]
Evaluating: 100%|██████████| 25/25 [00:02<00:00,  8.98it/s]


{'Accuracy': 0.805, 'Precision': 0.3333333333333333, 'Recall': 0.02631578947368421, 'F1 Score': 0.048780487804878044}


Epoch 5/5: 100%|██████████| 100/100 [00:33<00:00,  2.99it/s]
Evaluating: 100%|██████████| 25/25 [00:02<00:00,  8.99it/s]


{'Accuracy': 0.8, 'Precision': 0.25, 'Recall': 0.02631578947368421, 'F1 Score': 0.04761904761904762}


0,1
Accuracy,▁████
F1 Score,█▁▁▂▂
Precision,▅▁▁█▆
Recall,█▁▁▁▁
Train Loss,▇█▅▄▁

0,1
Accuracy,0.8
F1 Score,0.04762
Precision,0.25
Recall,0.02632
Train Loss,0.92847


In [109]:

print(f"TN: {tn}, TP: {tp}, FN: {fn}, FP: {fp}")
print(f"Accuracy: {accuracy:.4f}")
print(f"Precision: {precision:.4f}")
print(f"Recall (Sensitivity): {recall:.4f}")
print(f"F1-Score: {f1:.4f}")
print(f"Specificity: {specificity:.4f}")
print(f"False Positive Rate: {fpr:.4f}")
print(f"Negative Predictive Value: {npv:.4f}")

# Inference function
def predict(image_path, caption):
    model.eval()
    with torch.no_grad():
        image = Image.open(image_path).convert("RGB")
        image = train_dataset.transform(image).unsqueeze(0).to(device)
        inputs = tokenizer(caption, return_tensors="pt", max_length=32, padding="max_length", truncation=True)
        input_ids = inputs["input_ids"].to(device)
        attention_mask = inputs["attention_mask"].to(device)
        output = model(image, input_ids, attention_mask)
        pred = (output > 0.5).item()
    return "Positive" if pred == 1 else "Negative"

# Test the inference function
print(predict("dataset/images/00002/000025708.jpg", "A group of people flying a large colorful kite in winter."))

print(f"Number of images/samples being used: {len(data)}")



TN: 168, TP: 0, FN: 32, FP: 0
Accuracy: 0.8400
Precision: 0.0000
Recall (Sensitivity): 0.0000
F1-Score: 0.0000
Specificity: 1.0000
False Positive Rate: 0.0000
Negative Predictive Value: 0.8400
Negative
Number of images/samples being used: 2000
