In [35]:
import numpy as np
import random
import plotly.graph_objects as go

num_samples = 1000
fakeDatasetProperties = dict(
    min_length = 20,
    max_length = 100,
    min_amplitude = 1,
    max_amplitude = 4,
    min_frequency = 1,
    max_frequency = 10,
    sampling_rate = 200)

class SineDataset:
    def __init__(self, dataset_properties):
        self.fakeDatasetProperties = dataset_properties
    def normalize_dataset(self, dataset, conditions, train_split=0.8):
        trainDataset = dataset[:int(len(dataset) * train_split)]
        # Calculate mean and standard deviation for each sample
        means = [np.mean(sample) for sample in trainDataset]
        stds = [np.std(sample) for sample in trainDataset]
        # Calculate weights based on the length of each sample
        weights = [len(sample) for sample in trainDataset]
        # Compute the global mean and standard deviation as a weighted average of the individual means and standard deviations
        self.mean_traj = np.array([np.average(means, weights=weights)])
        self.std_traj = np.array([np.average(stds, weights=weights)])
        # Normalize each sample using the global mean and standard deviation
        normalized_dataset = [(sample - self.mean_traj) / self.std_traj for sample in dataset]
        train_trajs = normalized_dataset[:int(len(normalized_dataset) * train_split)]
        test_trajs = normalized_dataset[int(len(normalized_dataset) * train_split):]
        # TODO should the conditions be split up like this?
        self.mean_condition = np.array([np.mean(conditions[:int(len(conditions) * train_split)])])
        self.std_condition = np.array([np.std(conditions[:int(len(conditions) * train_split)])])
        normalized_conditions = [(condition - self.mean_condition) / self.std_condition for condition in conditions]
        train_conditions = normalized_conditions[:int(len(normalized_conditions) * train_split)]
        test_conditions = normalized_conditions[int(len(normalized_conditions) * train_split):]
        self.mean_length = np.array([np.mean(weights)])
        self.std_length = np.array([np.std(weights)])
        return (train_trajs, train_conditions), (test_trajs, test_conditions)
    
    def generate_sine_wave(self, min_length, max_length, min_amplitude, max_amplitude, min_frequency, max_frequency, sampling_rate):
        # randomly select the sample length, amplitude, and frequency
        sample_length = random.randint(min_length, max_length)
        amplitude = random.uniform(min_amplitude, max_amplitude)
        # Ensure each sample includes at least two oscillations
        frequency = random.uniform(max(min_frequency, 2 / (sample_length / sampling_rate)), max_frequency)
        # generate the time points for the sine wave
        time = np.arange(sample_length) / sampling_rate
        # generate the sine wave
        wave = amplitude * np.sin(2 * np.pi * frequency * time)
        wave = np.expand_dims(wave, axis=1)
        return wave, amplitude, frequency

    def generate_dataset(self, num_samples, **kwargs):
        # create an empty list to store the sine waves
        dataset = []
        params = []
        
        for _ in range(num_samples):
            wave, amplitude, frequency = self.generate_sine_wave(**kwargs)
            dataset.append(wave)
            params.append([amplitude, frequency])
        
        return dataset, params

sineDataset = SineDataset(fakeDatasetProperties)
trajs, conditions = sineDataset.generate_dataset(num_samples, **fakeDatasetProperties)
(train_trajs, train_conditions), (test_trajs, test_conditions) = sineDataset.normalize_dataset(trajs, conditions)
print(sineDataset.mean_traj)
print(sineDataset.std_traj)


[0.14165295]
[1.76083908]


In [36]:
fig = go.Figure()
# find the shortest frequency in dataset
shortest_freq = fakeDatasetProperties['max_frequency']
for i in range(num_samples):
    if conditions[i][1] < shortest_freq:
        shortest_freq = conditions[i][1]
        index = i
# find highest frequency in dataset
highest_freq = fakeDatasetProperties['min_frequency']
fig.add_trace(go.Scatter(y=np.squeeze(trajs[index]), mode='markers', name='lowest frequency'))
for i in range(num_samples):
    if conditions[i][1] > highest_freq:
        highest_freq = conditions[i][1]
        index = i
fig.add_trace(go.Scatter(y=np.squeeze(trajs[index]), mode='markers', name='highest frequency'))
fig.show()
fig = go.Figure()
for i in range(100):
    fig.add_trace(go.Scatter(y=np.squeeze(trajs[i]), mode='lines'))
fig.update_layout(title='First 100 samples from the dataset', xaxis_title='Time', yaxis_title='Amplitude')
fig.show()

In [37]:
from src.mouseGAN.LR_schedulers import *
from src.mouseGAN.model_config import Config, LR_SCHEDULERS, LOSS_FUNC, \
    C_MiniBatchDisc, C_Discriminator, C_Generator, C_EMA_Plateua_Sch, \
    C_Step_Sch, C_LossGap_Sch
from src.mouseGAN.models import MouseGAN
from src.mouseGAN.experimentTracker import initialize_wandb
from src.mouseGAN.dataProcessing import MouseGAN_Data
from src.mouseGAN.dataset import getDataloader


IN_COLAB = False
LOAD_PRETRAINED = True
BATCH_SIZE = 256
num_epochs = 1000
num_feats = train_trajs[0].shape[1]
latent_dim = 20
num_target_feats = train_conditions[0].shape[0]
numBatches = len(train_trajs)//BATCH_SIZE
MAX_SEQ_LEN = max([len(traj) for traj in train_trajs + test_trajs])

D_config = C_Discriminator(lr=0.01, bidirectional=True, hidden_units=128, 
                            num_lstm_layers=1, 
                            useEndDeviationLoss=False,
                            gradient_maxNorm = 1.0,
                            spectral_norm = True)
G_config = C_Generator(lr=0.0002, hidden_units=128, num_lstm_layers=3, drop_prob=0.4,
                # layer_normalization = True,
                residual_connections = True,
                gradient_maxNorm = 1.0,
                useSeqLengthLoss=False,
                useOutsideTargetLoss=False,
                usePathLengthLoss=False)

D_sch_config = C_LossGap_Sch(cooldown=int(numBatches/4), lr_shrinkMin=0.1, lr_growthMax=2.0, 
                            discLossDecay=0.8, lr_max = D_config.lr, lr_min = 0.0001, restart_after=None)
# G_sch_config = C_EMA_Plateua_Sch(patience=BATCH_SIZE, cooldown=int(BATCH_SIZE/8), factor=0.5, ema_alpha=0.4)

config = Config(num_epochs, BATCH_SIZE, num_feats, latent_dim, num_target_feats, MAX_SEQ_LEN,
                discriminator=D_config, generator=G_config, 
                D_lr_scheduler=D_sch_config, #G_lr_scheduler=G_sch_config,
                locationMSELoss = True)

## verifying the mean trajectory is centered around zero (even class distribution)
# dataset.plotMeanPath()
trainLoader = getDataloader(train_trajs, train_conditions, config.BATCH_SIZE)
testLoader = getDataloader(test_trajs, test_conditions, config.BATCH_SIZE)

 
# if IN_COLAB:
# run = initialize_wandb(config, tempProjectName='sineGAN')
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
gan = MouseGAN(sineDataset, trainLoader, testLoader, device, config, IN_COLAB=IN_COLAB, verbose=True, printBatch=True)
if LOAD_PRETRAINED:
    gan.loadPretrained(startingEpoch=300)

gan.visualTrainingVerfication(samples=10)

# print(gan.discriminator)
# print(gan.generator)

# gan.find_learning_rates_for_GAN()
gan.train(modelSaveInterval=20, catchErrors=False, visualCheckInterval=5)
# if IN_COLAB:
#     wandb.finish()


torch.cuda.amp.GradScaler is enabled, but CUDA is not available.  Disabling.



AttributeError: 'SineDataset' object has no attribute 'mean_length'

In [None]:
gan.visualTrainingVerfication(samples=50)

In [None]:
gan.visualTrainingVerfication(samples=10)

In [None]:
gan.visualTrainingVerfication(samples=10)

In [None]:
gan.train(modelSaveInterval=3, catchErrors=False)

In [None]:
# gan.save_models('final')
gan.loadPretrained(startingEpoch=99)

In [None]:
for epoch in ['final']:
    gan.loadPretrained(startingEpoch=epoch)
    gan.visualTrainingVerfication()

In [None]:
a = torch.tensor(1.5)
a.requires_grad_()
b = torch.greater(a, 1.0)
# b = torch.round(a)
b.backward()
a.grad