## Dependencies

In [None]:
import cv2
import os
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset, random_split
from torchvision import transforms
from PIL import Image
import glob
import random
import numpy as np
from utils import ImageDataset



# Data root: the "data" folder under the kernel's current working directory.
# NOTE: despite its name, `cwd` holds this data root, not the working directory itself;
# the name is kept because later cells may refer to it.
cwd = os.path.join(os.getcwd(), "data")
# One sub-folder per image role: anchors, positives, negatives.
anc_path, pos_path, neg_path = (os.path.join(cwd, folder) for folder in ("anc", "pos", "neg"))

### Setup and manage GPU usage

In [2]:
# Pick the compute device. With CUDA present, also cap this process's GPU memory,
# enable cuDNN autotuning, and print a short diagnostic summary (one item per line).
GPU_MEMORY_FRACTION = 0.8  # share of the device's memory the PyTorch allocator may use

if torch.cuda.is_available():
    device = torch.device("cuda")
    current_device = torch.cuda.current_device()
    # Cap the caching allocator at GPU_MEMORY_FRACTION of this device's memory;
    # allocating beyond the cap raises an out-of-memory error in the allocator
    # instead of taking the whole card.
    torch.cuda.set_per_process_memory_fraction(GPU_MEMORY_FRACTION, device=current_device)
    # Let cuDNN benchmark several convolution algorithms per input shape and cache
    # the fastest. This speeds up fixed-size inputs, but makes algorithm selection
    # non-deterministic (hurts bit-exact reproducibility) and may use extra workspace.
    torch.backends.cudnn.benchmark = True
    # sep="\n" puts each item on its own line without the trailing spaces that
    # embedding "\n" in separate arguments (default sep=" ") produced.
    print(
        f"Using Device: {device}",
        f"Current GPU: {torch.cuda.get_device_name(current_device)}",
        f"Cuda version: {torch.version.cuda}",
        f"cuDNN available: {torch.backends.cudnn.is_available()}",
        f"cuDNN version: {torch.backends.cudnn.version()}",
        f"Allocated memory: {torch.cuda.memory_allocated(current_device)} bytes",
        f"Cached memory: {torch.cuda.memory_reserved(current_device)} bytes",
        sep="\n",
    )
else:
    device = torch.device("cpu")
    print(f"Using Device: {device}")


Using Device: cuda 
Current GPU: NVIDIA GeForce RTX 4060 Laptop GPU 
Cuda version: 11.8 
cuDNN available: True 
cuDNN version: 90100 
Allocated memory: 0 bytes 
Cached memory: 0 bytes


## Data preprocessing

In [3]:
# Tunable settings for this stage.
IMAGE_SIZE = (100, 100)  # every image is resized to this (height, width)
BATCH_SIZE = 16
TRAIN_FRACTION = 0.7     # share of samples used for training; the rest is validation
SPLIT_SEED = 42          # fixed so the train/validation split is identical on every run

# Resize to a fixed size, then convert the image to a tensor.
transform = transforms.Compose([
    transforms.Resize(IMAGE_SIZE),
    transforms.ToTensor(),
])

# ImageDataset comes from the local utils module and receives the anchor, positive
# and negative folders (in that order) plus the transform.
# NOTE(review): presumably yields (anchor, positive, negative) samples -- confirm in utils.
dataset = ImageDataset(anc_path, pos_path, neg_path, transform)

# Train/validation split. A dedicated seeded generator makes the split reproducible
# across kernel restarts without touching the global RNG.
train_size = int(TRAIN_FRACTION * len(dataset))
val_size = len(dataset) - train_size
train_data, val_data = random_split(
    dataset,
    [train_size, val_size],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

# Settings shared by both loaders; pinned host memory speeds up copies to the GPU.
loader_kwargs = dict(
    batch_size=BATCH_SIZE,
    num_workers=4,
    prefetch_factor=8,
    pin_memory=torch.cuda.is_available(),
)
train_loader = DataLoader(train_data, shuffle=True, **loader_kwargs)
# Shuffling does not change validation metrics; a fixed order keeps evaluation deterministic.
val_loader = DataLoader(val_data, shuffle=False, **loader_kwargs)


