In [1]:
# Imports for Tensor
import bisect
import csv
import itertools
import matplotlib.pyplot as plt
import math
import numpy as np
import os
import pandas as pd
import shutil
import sys
from collections import OrderedDict
from datetime import datetime
from tempfile import TemporaryDirectory
from typing import Tuple

from tqdm.auto import tqdm

import torch
import torch.nn.functional as F
from torch import nn, Tensor
from torch.utils.tensorboard import SummaryWriter
from torch.nn import TransformerEncoder, TransformerEncoderLayer
from torch.utils.data import dataset
from torchvision import transforms

from diffusers import StableDiffusionPipeline
from datasets import load_dataset

sys.path.append("../")

%load_ext autoreload
%autoreload 2

In [2]:
# from common.dog import DoG, LDoG, PDoG
# from cnn_models import model_size
# from cnn_models import CNNClassifierLight, ResNet, LabelSmoothingCrossEntropy
# from cnn_models import EfficientNet, ShuffleNet, ResNet
from data_processing.math import MathDataset
from data_processing.parkinsons import ParkinsonsDataset, health_class_labels, HealthSampler
from data_processing.seed import SEEDDataset, emotion_class_labels, EmotionSampler
from data_processing.general_dataset import GeneralPreprocessor, GeneralDataset, GeneralSampler
from data_processing.general_dataset import general_class_labels, general_dataset_map
# from training import train_class, evaluate_class, TrainingConfig
# from visualization import *

In [3]:
# Seed NumPy's global RNG and PyTorch's RNG from a single value.
random_seed = 205 # 205 gave a good split for training
np.random.seed(random_seed)
# NOTE: as the last expression of the cell, the return value of
# torch.manual_seed is what the notebook displays as this cell's output.
torch.manual_seed(random_seed)

In [4]:
# Datapaths
# `datadirs` maps a dataset key to its root directory and '<key>-stft' to the
# 'stfts' sub-directory of that root. Keys are inserted in the order
# math, math-stft, parkinsons, parkinsons-stft, seed, seed-stft.
datahome = '/data/shared/signal-diffusion'
# datahome = '/mnt/d/data/signal-diffusion'

datadirs = {}
for ds_key, ds_subdir in (
    ('math', 'eeg_math'),            # Math dataset
    ('parkinsons', 'parkinsons/'),   # Parkinsons dataset
    ('seed', 'seed/'),               # SEED dataset
):
    datadirs[ds_key] = f'{datahome}/{ds_subdir}'
    datadirs[f'{ds_key}-stft'] = os.path.join(datadirs[ds_key], 'stfts')

In [29]:
# Value passed as the second positional argument of GeneralPreprocessor.
# NOTE(review): presumably the number of time samples per window — confirm
# against data_processing.general_dataset.
nsamps = 2000

# Build the preprocessor over all dataset directories configured in `datadirs`.
# NOTE(review): ovr_perc looks like a window-overlap fraction and fs a sampling
# rate in Hz — both are assumptions, verify in GeneralPreprocessor.
preprocessor = GeneralPreprocessor(datadirs, nsamps, ovr_perc=0.5, fs=125, bin_spacing="log")
# Split fractions passed here: 0.8 train, 0.2 validation, 0.0 test.
preprocessor.preprocess(resolution=256, train_frac=0.8, val_frac=0.2, test_frac=0.0)

In [30]:
# Batch size handed to the sampler below.
BATCH_SIZE = 32
# Per-dataset training splits, both read from the '<dataset>-stft' directories.
# NOTE(review): transform=None is passed explicitly for Parkinsons only; the
# SEED dataset relies on its own default — confirm the two are equivalent.
parkinsons_real_train_dataset = ParkinsonsDataset(datadirs['parkinsons-stft'], split="train", transform=None)
seed_real_train_dataset = SEEDDataset(datadirs['seed-stft'], split="train")
# Order matters: later cells turn flat indices into (dataset, local index)
# through real_train_set.cumulative_sizes and real_train_set.datasets.
real_train_datasets = [parkinsons_real_train_dataset, seed_real_train_dataset]
real_train_set = GeneralDataset(real_train_datasets, split='train')
# The sampler exposes a `weights` tensor that the following cells normalize.
train_samp = GeneralSampler(real_train_datasets, BATCH_SIZE, split='train')


In [31]:
# Print the length of each training dataset, of the combined set, and of the
# sampler, one value per line.
for sized_obj in (
    seed_real_train_dataset,
    parkinsons_real_train_dataset,
    real_train_set,
    train_samp,
):
    print(len(sized_obj))

In [32]:
# Rescale the sampler weights so that the smallest weight becomes exactly 1.
smallest_weight = train_samp.weights.min()
weights = train_samp.weights / smallest_weight
weights

In [36]:
# Plot the normalized weight of every entry in `weights` so the distinct weight
# levels can be read off the figure. Uses the explicit Axes interface and
# labels the figure so it stands on its own when the notebook is skimmed.
fig, ax = plt.subplots()
ax.plot(weights)
ax.set_xlabel("Sample index")
ax.set_ylabel("Sampling weight (relative to minimum)")
ax.set_title("Normalized sampler weights")
plt.show()

In [37]:
# List the distinct normalized weights, their total, and for each distinct
# weight the copy schedule it implies (extra copy every N samples).
weight_values = weights.numpy()
distinct_weights = set(weight_values)
print(distinct_weights)
print(sum(weight_values))
for wgt in distinct_weights:
    whole_copies = np.floor(wgt)
    # The 1e-5 term keeps the division finite when the weight is an integer.
    period = np.floor(1/(wgt - whole_copies + 1e-5))
    print(f"{wgt:.4f}: add {np.ceil(wgt)} every {period} else {whole_copies}")

# Resample Data

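The goal of this section is to write a class-balanced version of the training data to disk.

Samples are selected by their normalized weight. Each selected sample is copied a whole number of times, and every i-th sample in the group receives one extra copy (or, equivalently, one copy is skipped every i-th sample) so that the average number of copies per sample matches the fractional weight.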

## Make data copies

In [11]:
def make_copies(idxs, metadata, out_dir, copy_fn, meta_dataset=None):
    """Copy selected samples into ``out_dir`` and record a metadata row per copy.

    Every entry of ``idxs`` is a flat index into the concatenated dataset. It is
    resolved to a sub-dataset and a local index through ``cumulative_sizes``,
    and the sample's file is copied ``copy_fn(position)`` times to
    ``<out_dir>/<dataset.name>/<stem>_<c><ext>`` where ``c`` is the copy number.

    Args:
        idxs: Iterable of flat sample indices. Entries may be 0-dim tensors,
            NumPy integers or plain ints (anything ``int()`` accepts).
        metadata: List extended in place. One row is appended per written file:
            a copy of the sample's metadata row with ``file_name`` replaced by
            the new path relative to ``out_dir``.
        out_dir: Root directory of the resampled dataset.
        copy_fn: Callable receiving the 0-based position of the sample within
            ``idxs`` and returning how many copies to write.
        meta_dataset: Object exposing ``cumulative_sizes`` and ``datasets``;
            each sub-dataset must expose ``name``, ``datadir`` and a
            ``metadata`` DataFrame with a ``file_name`` column. Defaults to the
            notebook-level ``real_train_set`` for backward compatibility.

    Returns:
        None. Results are the files on disk and the rows added to ``metadata``.
    """
    if meta_dataset is None:
        # Backward-compatible fallback to the notebook global.
        meta_dataset = real_train_set

    boundaries = meta_dataset.cumulative_sizes
    for position, flat_idx in enumerate(idxs):
        flat_idx = int(flat_idx)
        # Locate the sub-dataset containing flat_idx and the offset inside it.
        ds_pos = bisect.bisect_right(boundaries, flat_idx)
        local_idx = flat_idx if ds_pos == 0 else flat_idx - boundaries[ds_pos - 1]
        source = meta_dataset.datasets[ds_pos]
        row = source.metadata.iloc[local_idx]

        fn = row['file_name']
        # Split on the real extension instead of assuming a 4-character ".png".
        stem, ext = os.path.splitext(fn)
        # file_name may contain sub-directories; create them under out_dir.
        os.makedirs(os.path.dirname(os.path.join(out_dir, source.name, fn)), exist_ok=True)

        for c in range(copy_fn(position)):
            new_fn = f'{source.name}/{stem}_{c}{ext}'
            row_copy = row.copy()
            row_copy['file_name'] = new_fn
            metadata.append(row_copy)
            shutil.copyfile(os.path.join(source.datadir, fn), os.path.join(out_dir, new_fn))

In [38]:
# Root directory that receives the resampled copies and, later, metadata.csv.
out_dir = f'{datahome}/reweighted_meta_dataset'
# Accumulates one metadata row per copied file. The copy cells below append to
# this list, so re-run this cell first if any of them is executed again.
metadata = []

In [39]:
# Weight group 2.0124: two copies per sample, with a third copy for every
# 80th sample of the group.
in_group = (weights > 2) & (weights < 2.02)
idxs = torch.nonzero(in_group).flatten()


def copy_fn(count):
    """Return 3 when count is a multiple of 80, otherwise 2."""
    if count % 80 == 0:
        return 3
    return 2


make_copies(idxs, metadata, out_dir, copy_fn)
print(len(metadata))
print(metadata[-1])

In [40]:
# Weight group 2.0386: two copies per sample, with a third copy for every
# 25th sample of the group.
in_group = (weights > 2.02) & (weights < 2.1)
idxs = torch.nonzero(in_group).flatten()


def copy_fn(count):
    """Return 3 when count is a multiple of 25, otherwise 2."""
    if count % 25 == 0:
        return 3
    return 2


make_copies(idxs, metadata, out_dir, copy_fn)
print(len(metadata))
print(metadata[-1])

In [41]:
# Weight group 1: exactly one copy of every sample in the group.
in_group = (weights > 0.9) & (weights < 1.01)
idxs = torch.nonzero(in_group).flatten()


def copy_fn(count):
    """Return 1 regardless of the sample's position."""
    return 1


make_copies(idxs, metadata, out_dir, copy_fn)
print(len(metadata))
print(metadata[-1])

In [42]:
# Weight group 1.0262: one copy per sample, with a second copy for every
# 38th sample of the group.
in_group = (weights > 1.01) & (weights < 1.1)
idxs = torch.nonzero(in_group).flatten()


def copy_fn(count):
    """Return 2 when count is a multiple of 38, otherwise 1."""
    if count % 38 == 0:
        return 2
    return 1


make_copies(idxs, metadata, out_dir, copy_fn)
print(len(metadata))
print(metadata[-1])

## Save metadata

In [43]:
# Collect the per-copy metadata rows gathered by make_copies into one table.
metapd = pd.DataFrame(metadata)

In [44]:
# Display the assembled metadata table for a visual check before saving.
metapd

In [45]:
# Write the metadata table next to the copied files, without the index column.
metadata_csv_path = os.path.join(out_dir, 'metadata.csv')
metapd.to_csv(metadata_csv_path, index=False)