# Setup

The cells from here to the "Get Images and Metadata" section needs to be run for every session.

In [None]:
!nvidia-smi
!pip install exifread

Mon Dec 16 19:59:01 2024       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.104.05             Driver Version: 535.104.05   CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|   0  NVIDIA L4                      Off | 00000000:00:03.0 Off |                    0 |
| N/A   42C    P8              12W /  72W |      1MiB / 23034MiB |      0%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+
                                                                    

In [None]:
%cd /content/drive/MyDrive/T4SG St Jude/New Assets/
!ls


/content/drive/.shortcut-targets-by-id/1OzE9n_wYhYcpK_L5s-1Tx1zZAOXd_f78/T4SG St Jude/New Assets
 00053127-007_TIF.Jpg			  00203343-001_TIF.Jpg
 00053127-011_TIF.Jpg			  00203343-002_TIF.Jpg
 00054250-022_TIF.Jpg			  00203343-003_TIF.Jpg
 00054250-023_TIF.Jpg			  00203343-009_TIF.Jpg
 00054889-030_TIF.Jpg			  00203343-010_TIF.Jpg
 00055143-008_TIF.Jpg			  00203343-012_TIF.Jpg
 00055524-068_TIF.Jpg			  00203343-016_TIF.Jpg
 00055524-069_TIF.Jpg			  00203343-024_TIF.Jpg
 00055524-121_TIF.Jpg			  00203343-028_TIF.Jpg
 00055527-003_TIF.Jpg			  00203343-031_TIF.Jpg
 00055883-01-010c_TIF.Jpg		  00203343-034_TIF.Jpg
 00055883-01-010_TIF.Jpg		  00203343-036_TIF.Jpg
 00091968-643_TIF.Jpg			  00203343-037_TIF.Jpg
 00091968-653_TIF.Jpg			  00203343-038_TIF.Jpg
 00092847-476_TIF.Jpg			  00203343-039_TIF.Jpg
 00092850-021_TIF.Jpg			  00203343-042_TIF.Jpg
 00093026-002_TIF.Jpg			  00203343-043_TIF.Jpg
 00093026-005_TIF.Jpg			  00203343-044_TIF.Jpg
 00093508-134_TIF.Jpg			  00203343-046_TIF.Jpg
 0009

In [None]:
!apt-get install -y exiftool

Reading package lists... Done
Building dependency tree... Done
Reading state information... Done
Note, selecting 'libimage-exiftool-perl' instead of 'exiftool'
The following additional packages will be installed:
  libarchive-zip-perl libmime-charset-perl libsombok3 libunicode-linebreak-perl
Suggested packages:
  libposix-strptime-perl libencode-hanextra-perl libpod2-base-perl
The following NEW packages will be installed:
  libarchive-zip-perl libimage-exiftool-perl libmime-charset-perl libsombok3
  libunicode-linebreak-perl
0 upgraded, 5 newly installed, 0 to remove and 49 not upgraded.
Need to get 3,964 kB of archives.
After this operation, 23.5 MB of additional disk space will be used.
Get:1 http://archive.ubuntu.com/ubuntu jammy/main amd64 libarchive-zip-perl all 1.68-1 [90.2 kB]
Get:2 http://archive.ubuntu.com/ubuntu jammy/universe amd64 libimage-exiftool-perl all 12.40+dfsg-1 [3,717 kB]
Get:3 http://archive.ubuntu.com/ubuntu jammy/universe amd64 libmime-charset-perl all 1.012.2-1

# Get Images and Metadata


In [None]:
# Seeing metadata for one file - experimentation

import subprocess
import json

image_path = "00212416-020_TIF.Jpg"

# Run ExifTool and capture metadata as JSON
result = subprocess.run(["exiftool", "-j", image_path], capture_output=True, text=True)

if result.returncode == 0:
    metadata = json.loads(result.stdout)[0]  # Parse JSON output
    print(f"Metadata for {image_path}:")
    # print(json.dumps(metadata, indent=4))  # Pretty-print all metadata
    print(metadata["Subject"])
    print(metadata["Description"])
else:
    print(f"ExifTool failed: {result.stderr}")


Metadata for 00212416-020_TIF.Jpg:
['Rural', 'Flower', 'Growth', 'nature', 'Fair Weather', 'Sunflower', 'Sun']
00212416-020, 07-20-20, Memphis views, rural, farming, sunflowers, Nature, 


Using the metadata extraction code from earlier, separate out images by their Level 1 keywords. Only needs to be run once to retrieve image data and keywords.

In [None]:
import os
import subprocess
import json
from tqdm import tqdm

# Define extensions to skip and invalid keyword filters
extensions_to_skip = ['.gsheet', '.csv', '.bin', '.gz']
invalid_extensions = ['.tif', '.jpg', '.jpeg', '.png']

# Initialize keyword bank and image metadata storage
keyword_bank = set()  # Stores all valid keywords
image_data = []  # Stores tuples of (image_path, one-hot vector)

def is_valid_keyword(keyword):
    """Check if a keyword is valid."""
    if not isinstance(keyword, str) or len(keyword) <= 1:
        return False
    if any(keyword.endswith(ext) for ext in invalid_extensions):
        return False
    return True

# Extract metadata and prune keywords
def extract_and_prune_metadata(image_dir, max_images=None):
    global keyword_bank, image_data

    # Get list of images in the directory
    file_list = os.listdir(image_dir)[:max_images] if max_images else os.listdir(image_dir)

    for filepath_i in tqdm(file_list):
        # Skip non-image files
        if any(filepath_i.endswith(ext) for ext in extensions_to_skip):
            continue

        # Get full file path
        file_path = os.path.join(image_dir, filepath_i)

        # Use ExifTool to extract metadata
        result = subprocess.run(["exiftool", "-j", file_path], capture_output=True, text=True)

        try:
            metadata = json.loads(result.stdout)[0]
            # Extract keywords (Subjects) from metadata
            file_keywords = metadata.get("Subject", [])

            # Prune invalid keywords
            pruned_keywords = [word for word in file_keywords if is_valid_keyword(word)]

            # Update keyword bank
            keyword_bank.update(pruned_keywords)

            # Add metadata to image data
            file_name = os.path.basename(file_path)
            image_data.append((file_name, pruned_keywords))

        except Exception as e:
            print(f"Error processing {file_path}: {e}")

    print(f"Extracted metadata for {len(image_data)} images.")
    print(f"Keyword bank size: {len(keyword_bank)}")

# Generate one-hot encoding for the keywords
def generate_one_hot_vectors():
    global image_data, keyword_bank

    # Map keywords to indices
    keyword_to_index = {keyword: idx for idx, keyword in enumerate(sorted(keyword_bank))}

    # Generate one-hot vectors for each image
    for i, (file_path, keywords) in enumerate(image_data):
        one_hot_vector = [0] * len(keyword_to_index)
        for keyword in keywords:
            if keyword in keyword_to_index:
                one_hot_vector[keyword_to_index[keyword]] = 1
        image_data[i] = (file_path, one_hot_vector)  # Replace keywords with one-hot vector

        # print(f"File: {file_path}, One-Hot Vector: {one_hot_vector}")  # Debug

    print(f"Generated one-hot vectors for {len(image_data)} images.")
    # print(f"Keyword-to-index mapping: {keyword_to_index}")

# Example usage:
image_dir = image_dir = "/content/drive/MyDrive/T4SG St Jude/New Assets/"
# extract_and_prune_metadata(image_dir, max_images=100)

extract_and_prune_metadata(image_dir, max_images=None) # Full set

generate_one_hot_vectors()


100%|██████████| 2024/2024 [07:27<00:00,  4.52it/s]

Extracted metadata for 2015 images.
Keyword bank size: 632
Generated one-hot vectors for 2015 images.





Save image and keyword data. This and the above extraction code would need to be modified for actual layer 1 keywords.

In [None]:
import json

# Paths to save the data
image_data_path = "image_data.json"
keyword_bank_path = "keyword_bank.json"

# Save image_data
with open(image_data_path, 'w') as f:
    json.dump(image_data, f)  # image_data is a list of tuples

# Save keyword_bank
with open(keyword_bank_path, 'w') as f:
    json.dump(list(keyword_bank), f)  # Convert set to list for JSON serialization

print(f"Saved image_data to {image_data_path}")
print(f"Saved keyword_bank to {keyword_bank_path}")

ERROR:root:Internal Python error in the inspect module.
Below is the traceback from this internal error.

ERROR:root:Internal Python error in the inspect module.
Below is the traceback from this internal error.

ERROR:root:Internal Python error in the inspect module.
Below is the traceback from this internal error.



Traceback (most recent call last):
  File "/usr/local/lib/python3.10/dist-packages/IPython/core/interactiveshell.py", line 3553, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "<ipython-input-24-a3a256eb4233>", line 8, in <cell line: 8>
    with open(image_data_path, 'w') as f:
OSError: [Errno 107] Transport endpoint is not connected: 'image_data.json'

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/usr/local/lib/python3.10/dist-packages/IPython/core/interactiveshell.py", line 2099, in showtraceback
    stb = value._render_traceback_()
AttributeError: 'OSError' object has no attribute '_render_traceback_'

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/usr/local/lib/python3.10/dist-packages/IPython/core/ultratb.py", line 1101, in get_records
    return _fixed_getinnerframes(etb, number_of_lines_of_context, tb_offset)
  File "/usr/

Write Pytorch classes for DataLoaders and Datasets to read in these images and their keyword labels. This also only needs to be ran once.

In [None]:
import torch
from PIL import Image
from torch.utils.data import Dataset
from torchvision.transforms import transforms

class MultiLabelImageDataset(Dataset):
    def __init__(self, image_data, transform=None):
        self.image_data = image_data
        self.transform = transform

    def __len__(self):
        return len(self.image_data)

    def __getitem__(self, idx):
        # Load image and its one-hot labels
        image_path, labels = self.image_data[idx]

        image = Image.open(image_path).convert("RGB")
        if self.transform:
            image = self.transform(image)
        return image_path, image, torch.tensor(labels, dtype=torch.float32)

Import and load dataset.

In [None]:
from torch.utils.data import DataLoader
from torchvision.transforms import transforms

# Define transformations
transform = transforms.Compose([
    transforms.Resize((224, 224)),  # Resize to 224x224
    transforms.ToTensor(),         # Convert to tensor
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])  # Normalize
])

# Create dataset
dataset = MultiLabelImageDataset(image_data, transform=transform)

# Split dataset into train and test sets
train_size = int(0.8 * len(dataset))
test_size = len(dataset) - train_size
train_dataset, test_dataset = torch.utils.data.random_split(dataset, [train_size, test_size])

# Check a few samples from the dataset
for i in range(5):  # Check the first 5 samples
    filepath, image, labels = test_dataset[i]
    print(f"Filepath: {filepath}")
    print(f"Labels: {labels}")

# Create data loaders
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

ERROR:root:Internal Python error in the inspect module.
Below is the traceback from this internal error.

ERROR:root:Internal Python error in the inspect module.
Below is the traceback from this internal error.

ERROR:root:Internal Python error in the inspect module.
Below is the traceback from this internal error.



Traceback (most recent call last):
  File "/usr/local/lib/python3.10/dist-packages/IPython/core/interactiveshell.py", line 3553, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "<ipython-input-25-4694a14d34bd>", line 21, in <cell line: 20>
    filepath, image, labels = test_dataset[i]
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataset.py", line 412, in __getitem__
    return self.dataset[self.indices[idx]]
  File "<ipython-input-24-ddd77810d453>", line 18, in __getitem__
    image = Image.open(image_path).convert("RGB")
  File "/usr/local/lib/python3.10/dist-packages/PIL/Image.py", line 3466, in open
    filename = os.path.realpath(os.fspath(fp))
  File "/usr/lib/python3.10/posixpath.py", line 397, in realpath
    return abspath(path)
  File "/usr/lib/python3.10/posixpath.py", line 384, in abspath
    cwd = os.getcwd()
OSError: [Errno 107] Transport endpoint is not connected

During handling of the above exception, another exception occurr

In [None]:
print(dataset)

<__main__.MultiLabelImageDataset object at 0x785ee4df31f0>


Save the dataset for future use.

In [None]:
# Save dataset
torch.save(dataset, './dataset.pt')

ERROR:root:Internal Python error in the inspect module.
Below is the traceback from this internal error.



Traceback (most recent call last):
  File "/usr/local/lib/python3.10/dist-packages/IPython/core/interactiveshell.py", line 3553, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "<ipython-input-26-f183bb18847c>", line 2, in <cell line: 2>
    torch.save(dataset, './dataset.pt')
  File "/usr/local/lib/python3.10/dist-packages/torch/serialization.py", line 849, in save
    with _open_zipfile_writer(f) as opened_zipfile:
  File "/usr/local/lib/python3.10/dist-packages/torch/serialization.py", line 716, in _open_zipfile_writer
    return container(name_or_buffer)
  File "/usr/local/lib/python3.10/dist-packages/torch/serialization.py", line 687, in __init__
    super().__init__(torch._C.PyTorchFileWriter(self.name))
RuntimeError: Parent directory . does not exist.

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/usr/local/lib/python3.10/dist-packages/IPython/core/interactiveshell.py", line 2099, in show

ERROR:root:Internal Python error in the inspect module.
Below is the traceback from this internal error.

ERROR:root:Internal Python error in the inspect module.
Below is the traceback from this internal error.



OSError: [Errno 107] Transport endpoint is not connected
Traceback (most recent call last):
  File "/usr/local/lib/python3.10/dist-packages/IPython/core/interactiveshell.py", line 3553, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "<ipython-input-26-f183bb18847c>", line 2, in <cell line: 2>
    torch.save(dataset, './dataset.pt')
  File "/usr/local/lib/python3.10/dist-packages/torch/serialization.py", line 849, in save
    with _open_zipfile_writer(f) as opened_zipfile:
  File "/usr/local/lib/python3.10/dist-packages/torch/serialization.py", line 716, in _open_zipfile_writer
    return container(name_or_buffer)
  File "/usr/local/lib/python3.10/dist-packages/torch/serialization.py", line 687, in __init__
    super().__init__(torch._C.PyTorchFileWriter(self.name))
RuntimeError: Parent directory . does not exist.

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/usr/local/lib/python3.10/dist-packa

# Training the Model
Load and use transfer learning to train a new model

In [None]:
# Load model
import torch.nn as nn
from torchvision.models import resnet50

# Load pre-trained ResNet50
model = resnet50(pretrained=True)

for param in model.parameters():
    param.requires_grad = False

for param in model.fc.parameters():
    param.requires_grad = True

# Replace the final fully connected layer
num_classes = len(keyword_bank)  # Number of unique labels
num_features = model.fc.in_features
model.fc = nn.Sequential(
    nn.Linear(num_features, num_classes),  # Output matches the number of classes
    nn.Sigmoid()  # Sigmoid for multi-label classification
)

In [None]:
import torch.optim as optim

criterion = nn.BCEWithLogitsLoss()  # Binary Cross-Entropy Loss
optimizer = optim.Adam(model.parameters(), lr=0.001)

In [None]:
# Train
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# Training loop
for epoch in range(10):  # Number of epochs
    model.train()
    running_loss = 0.0

    for filepaths, images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)

        # Zero the parameter gradients
        optimizer.zero_grad()

        # Forward pass
        outputs = model(images)
        loss = criterion(outputs, labels)

        # Backward pass and optimize
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    print(f"Epoch {epoch+1}, Loss: {running_loss / len(train_loader)}")

Epoch 1, Loss: 0.6931838030908622
Epoch 2, Loss: 0.6931822054526385
Epoch 3, Loss: 0.693178813831479
Epoch 4, Loss: 0.6931770630911285
Epoch 5, Loss: 0.6931744603549733


Save the model for future use.

In [None]:
# Define a path to save the model
MODEL_SAVE_PATH = "model.pth"

torch.save(model.state_dict(), MODEL_SAVE_PATH)

# Loading existing image data, dataset, and model.
Run the cells below if you have already run the above cells in a different runtime to create image data, the keywords, the dataset, and have trained a model.

In [None]:
import json

# Path to the saved data (use the full path if the files are in Drive)
image_data_path = "image_data.json"
keyword_bank_path = "keyword_bank.json"

# Load image_data
with open(image_data_path, 'r') as f:
    image_data = json.load(f)

# Load keyword_bank
with open(keyword_bank_path, 'r') as f:
    keyword_bank = set(json.load(f))  # Convert list back to set

# Check the data
print("Loaded image_data:", image_data[:2])  # Print first 2 items for preview
print("Loaded keyword_bank:", list(keyword_bank)[:5])  # Print first 5 keywords

Loaded image_data: [['00115417-140_TIF.Jpg', [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 

In [None]:
from torch.utils.data import DataLoader

# Load dataset
dataset = torch.load('./dataset.pt')

# Split dataset into train and test sets
train_size = int(0.8 * len(dataset))
test_size = len(dataset) - train_size
train_dataset, test_dataset = torch.utils.data.random_split(dataset, [train_size, test_size])

# Create data loaders
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

  dataset = torch.load('./dataset.pt')


In [None]:
import torch
import torch.nn as nn
from torchvision.models import resnet50

# Path to the saved model
MODEL_SAVE_PATH = "model.pth"

# Define the model structure
model = resnet50(pretrained=True)

# Freeze all layers except the fully connected (fc) layer
for param in model.parameters():
    param.requires_grad = False

# Replace the final fully connected layer to match the saved model's structure
num_classes = 632  # Replace with the number of classes used during training
num_features = model.fc.in_features
model.fc = nn.Sequential(
    nn.Linear(num_features, num_classes),  # Output matches the number of classes
    nn.Sigmoid()  # Sigmoid for multi-label classification
)

# Load the state dictionary
model.load_state_dict(torch.load(MODEL_SAVE_PATH))

# Move the model to the device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

print("Model loaded successfully.")



Downloading: "https://download.pytorch.org/models/resnet50-0676ba61.pth" to /root/.cache/torch/hub/checkpoints/resnet50-0676ba61.pth
100%|██████████| 97.8M/97.8M [00:00<00:00, 217MB/s]
  model.load_state_dict(torch.load(MODEL_SAVE_PATH))


Model loaded successfully.


In [None]:
# Recreate keyword_to_index from keyword_bank
keyword_to_index = {keyword: idx for idx, keyword in enumerate(sorted(keyword_bank))}

# Create reverse mapping for evaluation
index_to_keyword = {idx: keyword for keyword, idx in keyword_to_index.items()}

# Use Model to Predict Metadata on Test Dataset

Instead of using a threshold (which would make sense in the long run), just pick top three most likely for now. This code can be modified to just predict for data that does not yet have metadata.

In [None]:
import os
import csv
import random

# Evaluate model
model.eval()
correct = 0
total = 0
true_positives = 0
false_positives = 0
false_negatives = 0

# Variables for random chance evaluation
random_correct = 0
random_true_positives = 0
random_false_positives = 0
random_false_negatives = 0

# CSV file setup
csv_file = "transfer_evaluation_results.csv"
if os.path.exists(csv_file):
    os.remove(csv_file)
csv_columns = ["filepath", "predicted", "actual", "random_predicted"]

with open(csv_file, mode="w", newline="") as file:
    writer = csv.DictWriter(file, fieldnames=csv_columns)
    writer.writeheader()

    with torch.no_grad():
        for filepaths, images, labels in test_loader:
            images, labels = images.to(device), labels.to(device)

            # Forward pass
            outputs = model(images)

            # Get the top 3 probabilities and their indices
            topk_probs, topk_indices = torch.topk(outputs, 3, dim=1)

            for idx, (image_tensor, top_indices, actual) in enumerate(zip(images, topk_indices, labels)):
                filepath, _, _ = test_dataset[idx]  # Fetch filepath and labels separately

                # Convert predictions and labels to keywords
                predicted_keywords = [index_to_keyword[i.item()] for i in top_indices]
                actual_keywords = [
                    index_to_keyword[i] for i, active in enumerate(actual.tolist()) if active > 0
                ]

                # Generate random predictions
                random_indices = random.sample(range(len(index_to_keyword)), 3)
                random_predicted_keywords = [index_to_keyword[i] for i in random_indices]

                # Write to CSV
                writer.writerow({
                    "filepath": filepath,
                    "predicted": ", ".join(predicted_keywords),
                    "actual": ", ".join(actual_keywords),
                    "random_predicted": ", ".join(random_predicted_keywords)
                })

                # Update precision and recall metrics for model predictions
                for pred_idx in top_indices:
                    if actual[pred_idx] > 0:  # True positive
                        true_positives += 1
                    else:  # False positive
                        false_positives += 1

                for label_idx, label_value in enumerate(actual.tolist()):
                    if label_value > 0 and label_idx not in top_indices:  # False negative
                        false_negatives += 1

                # Update metrics for random predictions
                for pred_idx in random_indices:
                    if actual[pred_idx] > 0:  # True positive
                        random_true_positives += 1
                    else:  # False positive
                        random_false_positives += 1

                for label_idx, label_value in enumerate(actual.tolist()):
                    if label_value > 0 and label_idx not in random_indices:  # False negative
                        random_false_negatives += 1

            # Calculate accuracy (if needed, optional for top-k evaluation)
            total += labels.size(0) * labels.size(1)
            correct += (labels.gather(1, topk_indices) > 0).sum().item()

# Calculate precision, recall, F1 score, and accuracy for model
precision = true_positives / (true_positives + false_positives) if (true_positives + false_positives) > 0 else 0
recall = true_positives / (true_positives + false_negatives) if (true_positives + false_negatives) > 0 else 0
f1_score = (2 * precision * recall) / (precision + recall) if (precision + recall) > 0 else 0
accuracy = correct / total

# Calculate precision, recall, F1 score, and accuracy for random chance
random_precision = random_true_positives / (random_true_positives + random_false_positives) if (random_true_positives + random_false_positives) > 0 else 0
random_recall = random_true_positives / (random_true_positives + random_false_negatives) if (random_true_positives + random_false_negatives) > 0 else 0
random_f1_score = (2 * random_precision * random_recall) / (random_precision + random_recall) if (random_precision + random_recall) > 0 else 0

# Print metrics for model
# print(f"Test Accuracy (Model): {100 * accuracy:.2f}%")
print(f"Precision (Model): {100 * precision:.2f}%")
print(f"Recall (Model): {100 * recall:.2f}%")
print(f"F1 Score (Model): {100 * f1_score:.2f}%")

# Print metrics for random chance
print(f"Precision (Random): {100 * random_precision:.2f}%")
print(f"Recall (Random): {100 * random_recall:.2f}%")
print(f"F1 Score (Random): {100 * random_f1_score:.2f}%")




Precision (Model): 10.01%
Recall (Model): 7.99%
F1 Score (Model): 8.89%
Precision (Random): 0.74%
Recall (Random): 0.59%
F1 Score (Random): 0.66%
