In [1]:
import librosa
import librosa.display
import matplotlib.pyplot as plt

def display_spectrogram(data):
    plt.imshow(data, aspect='auto', cmap='inferno')
    plt.colorbar()
    plt.title('Spectrogram')
    plt.xlabel('Time')
    plt.ylabel('Frequency')
    plt.show()


In [2]:
import torch
import os
import torch.nn as nn
from torch.utils.data import DataLoader
from utils.spectrogram_dataset import SpectrogramDataset
from utils.spectrogram_dataset import get_files_recursive

class ContextEncoder(nn.Module):
    def __init__(self):
        super(ContextEncoder, self).__init__()
        # Define the context encoder architecture (e.g., a Vision Transformer)
        self.encoder = nn.Sequential(
            torch.nn.Linear(5,5),
            torch.nn.ReLU()
        )

    def forward(self, x):
        return self.encoder(x)

class Predictor(nn.Module):
    def __init__(self):
        super(Predictor, self).__init__()
        # Define the predictor architecture
        self.predictor = nn.Sequential(
            torch.nn.Linear(5,5),
            torch.nn.ReLU()
        )

    def forward(self, x):
        return self.predictor(x)

class TargetEncoder(nn.Module):
    def __init__(self):
        super(TargetEncoder, self).__init__()
        # Define the target encoder architecture (e.g., a Vision Transformer)
        self.encoder = nn.Sequential(
            torch.nn.Linear(5,5),
            torch.nn.ReLU()
        )

    def forward(self, x):
        return self.encoder(x)

class IJEPA(nn.Module):
    def __init__(self):
        super(IJEPA, self).__init__()
        self.context_encoder = ContextEncoder()
        self.predictor = Predictor()
        self.target_encoder = TargetEncoder()

    def forward(self, context, target):
        context_embedding = self.context_encoder(context)
        predicted_target_embedding = self.predictor(context_embedding)
        target_embedding = self.target_encoder(target)
        return predicted_target_embedding, target_embedding

# Define the loss function (e.g., Mean Squared Error)
loss_function = nn.MSELoss()

# Instantiate the I-JEPA model
model = IJEPA()

# Define the optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
epochs = 20

dataset_path = './data2'
sample_len = 4096


# if os.path.exists('spectrogram_dataset.pth'):
#     dataset = torch.load('spectrogram_dataset.pth')

# else:
#     dataset = SpectrogramDataset(
#         folder=dataset_path,
#         files=get_files_recursive(dataset_path),  # Assuming you've imported the get_files function
#         crop_frames=sample_len
#     )
#     torch.save(dataset, 'spectrogram_dataset.pth')

dataset = SpectrogramDataset(
        folder=dataset_path,
        files=get_files_recursive(dataset_path),  # Assuming you've imported the get_files function
        crop_frames=sample_len
    )

print(dataset[0].shape)
print(dataset[1].shape)
dataloader = DataLoader(
    dataset[0],
    batch_size=50
)


print(dataset[1])
display_spectrogram(dataset[1])


# Training loop
for epoch in range(epochs):
    for context in dataloader:
        predicted_target_embedding, target_embedding = model(context, context)
        loss = loss_function(predicted_target_embedding, target_embedding)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()


tensor([[-80.0000, -80.0000, -80.0000,  ..., -29.9332, -29.0311, -22.8096],
        [-80.0000, -80.0000, -80.0000,  ..., -26.3848, -20.0094, -13.9930],
        [-80.0000, -80.0000, -80.0000,  ..., -19.8444, -19.7309, -17.3792],
        ...,
        [-80.0000, -80.0000, -80.0000,  ..., -80.0000, -80.0000, -80.0000],
        [-80.0000, -80.0000, -80.0000,  ..., -80.0000, -80.0000, -80.0000],
        [-80.0000, -80.0000, -80.0000,  ..., -80.0000, -80.0000, -80.0000]])
tensor([[-80., -80., -80.,  ...,   0.,   0.,   0.],
        [-80., -80., -80.,  ...,   0.,   0.,   0.],
        [-80., -80., -80.,  ...,   0.,   0.,   0.],
        ...,
        [-80., -80., -80.,  ...,   0.,   0.,   0.],
        [-80., -80., -80.,  ...,   0.,   0.,   0.],
        [-80., -80., -80.,  ...,   0.,   0.,   0.]])
tensor([[-80.0000, -80.0000, -80.0000,  ..., -29.9332, -29.0311, -22.8096],
        [-80.0000, -80.0000, -80.0000,  ..., -26.3848, -20.0094, -13.9930],
        [-80.0000, -80.0000, -80.0000,  ..., -19.844

KeyboardInterrupt: 