In [6]:
import numpy as np
import torch
import os
from torchvision import transforms
from PIL import Image
import random
import matplotlib.pyplot as plt
import torch.nn as nn
import torch.nn.functional as F

# Directory holding the extracted video frames. The default is a machine-specific
# absolute path; set the ZEBRAFISH_FRAMES_DIR environment variable to point the
# notebook at another location without editing the code.
LOCAL_IMAGE_DIRECTORY = os.path.join(
    os.environ.get(
        "ZEBRAFISH_FRAMES_DIR",
        "/Users/Shared/Files From c.localized/Work/Code/Virtual Zebrafish/YouTube Movies/Frames/",
    ),
    "",  # guarantees a trailing separator, so plain string concatenation with a file name still yields a valid path
)

# Single frame used for the quick smoke tests in the following cells.
FILENAME = "frame_0052.jpg"



In [7]:



# Function to display tensor as image
def show_image(tensor_img):
    """Display a single image tensor with matplotlib.

    Parameters:
    - tensor_img : Tensor laid out as (1, channels, height, width). The leading
      batch dimension is dropped and the remaining three dimensions are
      reordered to (height, width, channels) before plotting. The tensor may
      live on any device and may be part of an autograd graph.
    """
    # Remove the batch dimension
    img = tensor_img.squeeze(0)

    # Transpose the tensor dimensions to (height, width, channels)
    img = img.permute(1, 2, 0)

    # Detach from the computation graph and copy to host memory before the
    # NumPy conversion: .numpy() only works on CPU tensors, so a tensor that
    # was moved to 'mps' or 'cuda' would otherwise raise here.
    img = img.detach().cpu().numpy()

    # Display the image using matplotlib
    plt.imshow(img)
    plt.axis('off')  # Turn off axis numbers and ticks
    plt.show()


# Function to load an image and convert it to a PyTorch tensor
def load_image(file_path, size=(25, 25)):
    """Load an image file as a batched 3-channel tensor.

    Parameters:
    - file_path : Path of the image file to open
    - size : (height, width) handed to transforms.Resize. The default of
      (25, 25) is the resolution used throughout this notebook.

    Returns:
    - Tensor of shape (1, 3, height, width)
    """
    # Define transformations: resize to `size` and then convert to tensor
    transform = transforms.Compose([
        transforms.Resize(size),
        transforms.ToTensor(),  # Convert the PIL Image to a tensor
    ])

    # Keep the *opened file* under the context manager so its handle is always
    # released; convert() returns an independent in-memory copy that remains
    # valid after the file is closed.
    with Image.open(file_path) as img:
        rgb_img = img.convert('RGB')  # Convert to RGB for 3 channels

    # Apply the transformations to the image and add a batch dimension
    return transform(rgb_img).unsqueeze(0)

# Example usage
# os.path.join keeps the path valid whether or not the directory constant
# ends with a separator.
file_path = os.path.join(LOCAL_IMAGE_DIRECTORY, FILENAME)

tensor_img = load_image(file_path)

# Verify the shape: load_image resizes to 25x25, so a single RGB image
# must come back as [1, 3, 25, 25]
assert tensor_img.shape == (1, 3, 25, 25), tensor_img.shape

# Now you can pass `tensor_img` into your neural network


In [8]:
class ConvReluNet(nn.Module):
    """Two 3x3 convolution + ReLU stages with an optional fully connected head.

    Both convolutions use stride 1 and padding 1, so the spatial size of the
    input is preserved and the feature map after the second stage has shape
    (batch, 64, height, width).

    Parameters:
    - output_size : Number of output features of the fully connected layer
    - training : When True (default) the feature map is flattened and passed
      through the fully connected layer; when False the raw feature map is
      returned. Stored as ``self.training_mode``, which is independent of
      nn.Module's own ``train()`` / ``eval()`` state.
    - in_channels : Number of channels of the input images (default 3)
    - input_size : (height, width) of the input images; determines the number
      of input features of the fully connected layer (default (25, 25))
    """

    def __init__(self, output_size, training=True, in_channels=3, input_size=(25, 25)):
        super().__init__()

        # First Convolutional Layer followed by ReLU activation
        self.conv1 = nn.Conv2d(in_channels=in_channels, out_channels=32, kernel_size=3, stride=1, padding=1)
        self.relu1 = nn.ReLU()

        # Second Convolutional Layer followed by ReLU activation
        self.conv2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, stride=1, padding=1)
        self.relu2 = nn.ReLU()

        # Fully Connected Layer (for training)
        # after two conv layers with same padding, the spatial size is unchanged,
        # so the flattened feature map has 64 * height * width entries
        height, width = input_size
        self.fc = nn.Linear(in_features=64 * height * width, out_features=output_size)

        # State for adding the fully connected layer or not
        self.training_mode = training

    def forward(self, x):
        """Run the network on a batch of shape (batch, in_channels, height, width).

        Returns a (batch, output_size) tensor when ``training_mode`` is set,
        otherwise the (batch, 64, height, width) feature map.
        """
        x = self.relu1(self.conv1(x))
        x = self.relu2(self.conv2(x))

        # If the network is in training mode, then add the dense layer
        if self.training_mode:
            # flatten() gives the same result as view(batch, -1) on contiguous
            # activations but also works on non-contiguous ones (for example
            # channels_last inputs), where view() raises.
            x = self.fc(x.flatten(start_dim=1))
        return x

# Usage: build the network with its default settings (fully connected layer enabled)
output_size = 10  # Number of output features produced by the fully connected layer
net = ConvReluNet(output_size)

# Example input: the single image tensor returned by load_image() in an earlier cell

output = net(tensor_img)

print(output.shape)  # Should be torch.Size([1, output_size])



torch.Size([1, 10])


In [9]:
def extract_number_from_filename(fname):
    """Return the integer embedded in a file name such as ``frame_0052.jpg``.

    The number is read from the second underscore-separated segment of
    ``fname``, keeping only the text in front of that segment's first dot.

    Raises:
    - IndexError : if ``fname`` contains no underscore
    - ValueError : if the selected text is not a valid integer
    """
    segments = fname.split('_')
    second_segment = segments[1]
    digits, _, _ = second_segment.partition('.')
    return int(digits)


class TripletDataset(torch.utils.data.Dataset):
    """Dataset of (anchor, positive, negative) frame triplets.

    The frames are the ``.jpg`` files of one folder, ordered by the number
    embedded in their file name (see extract_number_from_filename). For the
    anchor frame at position ``i`` in that ordering:
    - the positive is a different frame at most ``pos_margin`` positions away
    - the negative is a frame more than ``pos_margin`` positions away
    Both are drawn uniformly at random with the ``random`` module on every access.

    Parameters:
    - folder_path : Directory containing the frame images
    - pos_margin : Largest anchor/positive distance, in frames (default 20)
    - image_size : (height, width) every frame is resized to (default (25, 25))
    """

    def __init__(self, folder_path, pos_margin=20, image_size=(25, 25)):
        frame_names = [fname for fname in os.listdir(folder_path) if fname.endswith('.jpg')]
        frame_names.sort(key=extract_number_from_filename)
        self.image_paths = [os.path.join(folder_path, fname) for fname in frame_names]
        self.pos_margin = pos_margin
        self.transform = transforms.Compose([
            transforms.Resize(image_size),  # Resize to the requested (height, width)
            transforms.ToTensor()  # Convert the PIL Image to a tensor
        ])

    def __len__(self):
        return len(self.image_paths)

    def _load_image(self, img_path):
        """Open one frame, force 3 RGB channels and apply the transform."""
        with Image.open(img_path) as img:
            # Without the conversion a grayscale or CMYK JPEG would produce a
            # tensor with a different channel count and break batching.
            return self.transform(img.convert('RGB'))

    def _window(self, index):
        """Return the half-open range [lo, hi) of positions within the margin of ``index``."""
        lo = max(0, index - self.pos_margin)
        hi = min(len(self.image_paths), index + self.pos_margin + 1)
        return lo, hi

    def _sample_positive_index(self, index):
        """Pick a random position inside the margin window, excluding ``index`` itself."""
        lo, hi = self._window(index)
        n_candidates = hi - lo - 1  # the window minus the anchor's own slot
        if n_candidates < 1:
            raise IndexError(
                f"no positive frame available for index {index}: "
                f"dataset has {len(self.image_paths)} frame(s), pos_margin={self.pos_margin}"
            )
        candidate = lo + random.randrange(n_candidates)
        # Candidates at or after the anchor are shifted by one to skip over it.
        return candidate + 1 if candidate >= index else candidate

    def _sample_negative_index(self, index):
        """Pick a random position outside the margin window around ``index``."""
        lo, hi = self._window(index)
        n_before = lo  # positions 0 .. lo-1
        n_after = len(self.image_paths) - hi  # positions hi .. len-1
        if n_before + n_after < 1:
            raise IndexError(
                f"no negative frame available for index {index}: "
                f"dataset has {len(self.image_paths)} frame(s), pos_margin={self.pos_margin}"
            )
        # Sample without materialising the full candidate list.
        pick = random.randrange(n_before + n_after)
        return pick if pick < n_before else hi + (pick - n_before)

    def __getitem__(self, index):
        """Return the tuple (anchor_img, positive_img, negative_img) for ``index``.

        Raises:
        - IndexError : if ``index`` is out of range, or if the dataset holds
          no frame that qualifies as positive / negative for this anchor
        """
        # Anchor path lookup first, so an out-of-range index fails immediately.
        anchor_path = self.image_paths[index]
        if index < 0:
            # Normalise Python-style negative indices before the window arithmetic.
            index += len(self.image_paths)

        # Positive Image within [margin] frames of the anchor
        positive_index = self._sample_positive_index(index)

        # Negative Image more than [margin] frames away from anchor
        negative_index = self._sample_negative_index(index)

        anchor_img = self._load_image(anchor_path)
        positive_img = self._load_image(self.image_paths[positive_index])
        negative_img = self._load_image(self.image_paths[negative_index])

        return anchor_img, positive_img, negative_img

# Build the triplet dataset over the frames directory and wrap it in a loader
# that yields shuffled batches of up to 32 (anchor, positive, negative) triplets.
folder_path = LOCAL_IMAGE_DIRECTORY
triplet_dataset = TripletDataset(folder_path)
triplet_loader = torch.utils.data.DataLoader(dataset=triplet_dataset, batch_size=32, shuffle=True)

# Example usage: take one batch and pass each of its three image stacks through
# `net`, the ConvReluNet instance created in an earlier cell.
anchor_imgs, positive_imgs, negative_imgs = next(iter(triplet_loader))
z_a = net(anchor_imgs)
z_p = net(positive_imgs)
z_n = net(negative_imgs)


In [10]:
class InfoNCELoss(nn.Module):
    """InfoNCE-style loss over (anchor, positive, negative) embedding triplets.

    Each anchor is scored against its positive and its single negative with a
    dot product; the two scores are treated as logits of a 2-way classification
    whose correct class is the positive.
    """

    def __init__(self, tau=0.1):
        """
        Initialize InfoNCE Loss module.

        Parameters:
        - tau : Temperature parameter for scaling logits; must be positive

        Raises:
        - ValueError : if tau is zero or negative
        """
        super().__init__()
        # A zero temperature divides by zero and a negative one flips the
        # sign of the logits, silently rewarding the wrong pair.
        if tau <= 0:
            raise ValueError(f"tau must be positive, got {tau}")
        self.tau = tau

    def forward(self, z_a, z_p, z_n):
        """
        Compute the InfoNCE loss.

        Parameters:
        - z_a : Tensor of shape (batch_size, feature_dim) representing anchor points
        - z_p : Tensor of shape (batch_size, feature_dim) representing positive examples
        - z_n : Tensor of shape (batch_size, feature_dim) representing negative examples

        Returns:
        - Loss : InfoNCE loss value (scalar tensor, mean over the batch)
        """
        # Compute the similarity (dot product) for positive and negative pairs
        sim_pos = (z_a * z_p).sum(dim=-1)  # (batch_size,)
        sim_neg = (z_a * z_n).sum(dim=-1)  # (batch_size,)

        # Calculate the logits: column 0 = positive, column 1 = negative,
        # both divided by the temperature
        logits = torch.stack([sim_pos, sim_neg], dim=1) / self.tau

        # Define the labels (always 0: the positive sits in column 0).
        # Allocated directly on the anchors' device to avoid a CPU allocation
        # followed by a host-to-device copy on every step.
        labels = torch.zeros(logits.size(0), dtype=torch.long, device=z_a.device)

        # Compute the cross entropy loss
        return F.cross_entropy(logits, labels)

In [61]:
# Sanity check: evaluate the loss (default temperature) on the embeddings z_a, z_p, z_n
# computed in an earlier cell. For reference, when the positive and negative
# similarities are equal the two-way cross entropy is ln(2) ~= 0.693.
infonce = InfoNCELoss()
infonce(z_a,z_p,z_n)

tensor(0.6849, grad_fn=<NllLossBackward0>)

In [68]:
import torch
from torch import optim
from tqdm import tqdm


output_size = 10  # Specify your desired output size

# Pick the best available accelerator instead of assuming Apple's MPS backend,
# so the cell also runs on CUDA machines and on plain CPUs.
if torch.backends.mps.is_available():
    device = torch.device('mps')
elif torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# Fresh model for training, moved to the selected device before the optimizer
# is created so the optimizer tracks the on-device parameters.
model = ConvReluNet(output_size).to(device)
model.train()

folder_path = LOCAL_IMAGE_DIRECTORY
triplet_dataset = TripletDataset(folder_path)
dataloader = torch.utils.data.DataLoader(dataset=triplet_dataset, batch_size=256, shuffle=True)


loss_function = InfoNCELoss()

# Initialize the optimizer
optimizer = optim.Adam(model.parameters())

# Number of epochs
num_epochs = 3

for epoch in range(num_epochs):
    total_loss = 0.0

    for anchors, positives, negatives in dataloader:
        # Move the data to the same device as the model
        anchors = anchors.to(device)
        positives = positives.to(device)
        negatives = negatives.to(device)

        # Forward pass
        anchor_out = model(anchors)
        positive_out = model(positives)
        negative_out = model(negatives)

        # Compute the loss
        loss = loss_function(anchor_out, positive_out, negative_out)

        # Backward pass and optimization
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    # Mean of the per-batch losses over the epoch
    avg_loss = total_loss / len(dataloader)
    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {avg_loss:.4f}")


Epoch [1/3], Loss: 1.4844
Epoch [2/3], Loss: 0.2670
Epoch [3/3], Loss: 0.1231
