In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset

import matplotlib.pyplot as plt
import numpy as np
import time
from torch.optim.lr_scheduler import StepLR

from transformers import AutoFeatureExtractor, ASTForAudioClassification, AutoConfig
from datasets import load_dataset
import soundfile
import librosa
from transformers import ASTConfig, ASTModel
from tqdm import tqdm


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
device = torch.device("cuda:4" if torch.cuda.is_available() else "cpu")

In [3]:
y_arrays = []
for o in range(1):
    y_arrays.append(torch.load(f"spectrograms{o}.pt"))
y_data = torch.vstack(y_arrays)

# y_data = torch.load(f"spectrograms{0}.pt")
print(f"y_data: {y_data.shape}")
x_data = torch.load("x_data.pt")
print(f"x_data: {x_data.shape}")

y_val = torch.load("y_validation.pt")
print(f"y_val: {y_val.shape}")

  y_arrays.append(torch.load(f"spectrograms{o}.pt"))


y_data: torch.Size([1000, 1017, 126])
x_data: torch.Size([1000, 10000])
y_val: torch.Size([1000, 10000])


  x_data = torch.load("x_data.pt")
  y_val = torch.load("y_validation.pt")


In [4]:
class PairedDataset(Dataset):
    def __init__(self, x_data, y_data):
        assert len(x_data) == len(y_data), "Datasets must be of the same length"
        self.x_data = x_data
        self.y_data = y_data

    def __len__(self):
        return len(self.x_data)

    def __getitem__(self, idx):
        x = self.x_data[idx]
        y = self.y_data[idx]
        return x, y

In [11]:
torch.swapaxes(y_val[:, 0:100], 0, 1).shape

torch.Size([100, 1000])

In [12]:
sample_size = 100
dataset = PairedDataset(y_data[0:sample_size], torch.swapaxes(y_val[:, 0:100], 0, 1))
training_data = DataLoader(dataset, batch_size=10)
x_data = torch.swapaxes(x_data, 0, 1)[0:sample_size][0]

In [13]:
def linear_function(params, x, device):
    x = x.to(device)
    if params.size(1) > 1:
        a = params[:, 0].type(torch.float).unsqueeze(1).to(device)
        b = params[:, 1].type(torch.float).unsqueeze(1).to(device)
        return a*x + b
    else:
        return torch.zeros_like(params)

def quadratic_function(params, x, device):
    params = params.to(device)
    x = x.to(device)
    y = torch.zeros((params.size(0),x.shape[0])).to(device)
    if params.size(1) > 2:
        for n in range(len(params[0])):
            y += (params[:, n].type(torch.float).unsqueeze(1).to(device))*x**(3-n)
        return y
    else:
        return torch.zeros_like(params)

def cubic_function(params, x, device):
    params = params.to(device)
    x = x.to(device)
    y = torch.zeros((params.size(0),x.shape[0])).to(device)
    if params.size(1) > 3:
        for n in range(len(params[0])):
            y += (params[:, n].type(torch.float).unsqueeze(1).to(device))*x**(3-n)
        return y
    else:
        return torch.zeros_like(params)
    
def sin_function(params, x, device):
    x = x.to(device)
    if params.size(1) == 3:
        amplitude = params[:, 0].type(torch.float).unsqueeze(1).to(device)
        frequency = params[:, 1].type(torch.float).unsqueeze(1).to(device)
        phase = params[:, 2].type(torch.float).unsqueeze(1).to(device)
        return amplitude * torch.sin(2 * torch.pi * frequency * x + phase)
    else:
        return torch.zeros_like(x)

In [14]:
def new_loss(output, target, x):
    derivative_true = torch.autograd.grad(outputs=target.requires_grad_(True), inputs=x, grad_outputs=torch.ones_like(target), allow_unused=True)[0]
    if derivative_true is None:
        derivative_true = torch.zeros_like(output)
    derivative_pred = (torch.roll(output, shifts=-1, dims=0) - torch.roll(output, shifts=1, dims=0)) / (2 * 1e-8)
    mse_function = torch.mean(torch.abs(target - output)**3)
    mse_derivative = torch.mean(torch.abs(derivative_pred - derivative_true)**3)
    return mse_function + mse_derivative

In [15]:
class CustomModel(ASTForAudioClassification):
    def __init__(self, config, functions, x_data, device_name):
        super().__init__(config)
        self.functions = functions
        self.x_data = x_data
        self.params = sum(self.functions[1])
        self.device_name = device_name
        
        del self.classifier.dense

        self.flatten_layer = nn.Flatten()

        self.hidden_embedding = nn.Sequential(
            nn.Linear(932352, 128),
            nn.SELU(),
            nn.Linear(128, 64),
            nn.SELU(),
            nn.Linear(64, self.params),
        )

    def forward(self, input_values, targets):
        print(f"inputs: {input_values.shape}")
        inputs = self.audio_spectrogram_transformer(input_values)
        input = self.classifier.layernorm(inputs[0])
        print(f"input: {input.shape}")
        input = self.flatten_layer(input)
        print(f"flattened: {input.shape}")
        embedding = self.hidden_embedding(input)
        print(f"embedding: {embedding.shape}")
        embedding = embedding.view(-1, 12)
        print(f"viewed: {embedding.shape}")
        
        loss_func = nn.MSELoss()
        start_index = 0
        losses = []
        outputs = []
        
        for f in range(len(self.functions[0])):
            print(f"params: {embedding[:, start_index:start_index+self.functions[1][f]].shape}")
            print(f"x data: {self.x_data.shape}")
            output = self.functions[0][f](
                embedding[:, start_index:start_index+self.functions[1][f]], 
                self.x_data, 
                device=self.device_name
            ).to(self.device_name)
            print(f"output: {output.shape}")
            print(f"targets: {targets.shape}")
            outputs.append(output)
            loss = loss_func(output, targets)
            losses.append(loss)
            start_index += self.functions[1][f]        
        best_index = torch.argmin(torch.tensor(losses))
        best_func = self.functions[0][best_index]
        best_loss, best_out = losses[best_index], outputs[best_index]

        return best_out, best_loss, best_func, outputs, losses

In [16]:
config = ASTConfig()
functions = [[linear_function, quadratic_function, cubic_function, sin_function], [2,3,4,3]]

In [17]:
customModel = CustomModel(config, functions=functions, x_data=x_data, device_name=device).to(device)

In [18]:
loss_func = nn.MSELoss()
optimizer = optim.Adam(customModel.parameters(), lr=0.001)
scheduler = StepLR(optimizer, step_size=10, gamma=0.1)

In [20]:
epochs = 1
for epoch in range(epochs):
    start_time = time.time()
    train_loss = 0.0
    total_num = 0
    customModel.train()
    
    for train_batch,targets in training_data:
        best_out,_,_,_,_  = customModel(train_batch, targets)
        break


inputs: torch.Size([10, 1017, 126])
input: torch.Size([10, 1214, 768])
flattened: torch.Size([10, 932352])
embedding: torch.Size([10, 12])
viewed: torch.Size([10, 12])
params: torch.Size([10, 2])
x data: torch.Size([1000])
output: torch.Size([10, 1000])
targets: torch.Size([10, 1000])


RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:4 and cpu!

In [9]:
dummyData = torch.zeros((10, 1017, 128)).to(device)
print(dummyData.shape)
customModel(dummyData).shape

torch.Size([10, 1017, 128])


torch.Size([10, 1214, 768])

In [None]:
'''r = np.random.randint(sample_size)
sample = data[r, :, 0]
print(sample.shape)
# STFT parameters
n_fft = 256  # Number of FFT components
win_length = 256  # Window length
hop_length = 128  # Number of samples between frames

# Convert sample to complex tensor with the required dimensions
sample = sample.unsqueeze(0)  # Add batch dimension

# Apply STFT
spectrogram = torch.stft(sample, n_fft=n_fft, win_length=win_length, hop_length=hop_length, return_complex=True)
print(spectrogram.shape)
# Compute magnitude spectrogram
magnitude_spectrogram = torch.abs(spectrogram)
print(magnitude_spectrogram.shape)

# Convert to numpy for plotting
spectrogram_np = magnitude_spectrogram.squeeze().cpu().numpy()
print(spectrogram_np.size)
# Plot the spectrogram
plt.figure(figsize=(10, 4))
plt.imshow(20 * np.log10(spectrogram_np + 1e-8), aspect='auto', origin='lower', cmap='inferno')
plt.colorbar(label='Magnitude (dB)')
plt.xlabel('Time')
plt.ylabel('Frequency')
plt.title('Spectrogram of Polynomial Function Output')
plt.show()

plt.plot(data[r, :, 0].detach().cpu().numpy(), "-");
'''