# Stargan Main: Data partition for training and testing of the StargGAN

## Step 1 segment and filter

Create the windows of SSVEP signals for each subject

In [5]:
from core.utils2 import generate_ref_signal
from core.utils2 import fbcca
import numpy as np
import os

from core.utils import segment_and_filter_all_subjects

import random

import shutil

In [None]:
# constants: 
SRC_DIR = './data/processed'
TARGET_DIR = './data/final'
WINDOW_SIZE = 1024
LOWCUT = 6
HIGHCUT = 54
FS = 250
ORDER = 6

os.makedirs(TARGET_DIR, exist_ok=True)
# def segment_and_filter_all_subjects(src_dir, target_dir, window_size, lowcut, highcut, fs, order=6):
segment_and_filter_all_subjects(SRC_DIR, TARGET_DIR, WINDOW_SIZE, LOWCUT, HIGHCUT, FS, ORDER)

## Step 2 Train test split

### Calculate accuracies

In [7]:


# get the ref signals
N = 1024
FS = 250
N_HARMONICS = 3
N_SUBBANDS = 6
FREQS = [8, 10, 12, 14]
LOWEST_FREQ = 2
UPMOST_FREQ = 54
W = 2.2
IDX_FREQS = [0, 2, 4, 6] # [8, 10, 12, 14]
DATA_DIR = "./data/final"
FREQ_PHASE_SRC = "./data/raw/Freq_Phase.mat"

ref_signals = generate_ref_signal(FREQ_PHASE_SRC, freqs=FREQS, N=N, n_harmonics=N_HARMONICS, fs=FS)

accuracies = np.empty((35, len(IDX_FREQS)))
for i in range(1, 36):
    for freq in IDX_FREQS:
        actual_freq = IDX_FREQS.index(freq)
        label_dir = os.path.join(DATA_DIR, str(freq))
        for j in range(0, 4):
            file_path = os.path.join(label_dir, "S" + str(i) + "_" + str(j) + ".npy")
            if os.path.exists(file_path):
                segment = np.load(file_path).swapaxes(0, 1)
                pred = fbcca(segment, FS, N_SUBBANDS, "M1", W, ref_signals, LOWEST_FREQ, UPMOST_FREQ)
            accuracies[i-1, actual_freq] = 1 if pred == actual_freq + 1 else 0

print(accuracies)

[[1. 1. 1. 1.]
 [1. 1. 1. 1.]
 [1. 1. 1. 1.]
 [1. 1. 1. 1.]
 [1. 1. 1. 1.]
 [1. 1. 1. 1.]
 [1. 1. 1. 1.]
 [1. 1. 1. 1.]
 [1. 1. 1. 0.]
 [1. 1. 1. 1.]
 [1. 1. 1. 1.]
 [1. 1. 1. 1.]
 [1. 1. 1. 1.]
 [1. 1. 1. 1.]
 [1. 1. 1. 1.]
 [1. 1. 1. 1.]
 [1. 1. 1. 1.]
 [1. 1. 1. 1.]
 [1. 1. 1. 1.]
 [1. 1. 1. 1.]
 [1. 1. 1. 1.]
 [1. 1. 1. 1.]
 [1. 1. 1. 1.]
 [1. 1. 1. 1.]
 [1. 1. 1. 1.]
 [1. 1. 1. 1.]
 [1. 1. 1. 1.]
 [1. 1. 1. 1.]
 [1. 1. 1. 1.]
 [1. 1. 1. 1.]
 [1. 1. 1. 1.]
 [1. 1. 1. 1.]
 [0. 0. 0. 0.]
 [1. 1. 1. 1.]
 [1. 1. 1. 1.]]


In [8]:
acc_means = np.mean(accuracies, axis=1)
perfect_idx = np.where(acc_means == 1)[0]
perfect_idx

array([ 0,  1,  2,  3,  4,  5,  6,  7,  9, 10, 11, 12, 13, 14, 15, 16, 17,
       18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 33, 34])

### Split de data

Depending on the accuracy of each subject. We will train the model only with perfect scoring subjects

In [9]:
SRC_DIR = "./data/final"
TARGET_DIR = "./data"
SPLIT_RATIO = 0.85 # for train/validation

labels = os.listdir(SRC_DIR)
for label in labels:
    label_dir = os.path.join(SRC_DIR, label)
    files = os.listdir(label_dir)
    # Filter the files to only use the perfect accuracies
    files = [file for file in files if int(file.split("_")[0][1:]) in perfect_idx]
    print(f"Using a total of {len(files)} samples to process")
    random.shuffle(files)
    split_idx = int(len(files) * SPLIT_RATIO)
    train_files = files[:split_idx]
    val_files = files[split_idx:]
    train_dir = os.path.join(TARGET_DIR, "train", label)
    val_dir = os.path.join(TARGET_DIR, "val", label)
    if not os.path.exists(train_dir):
        os.makedirs(train_dir)
    if not os.path.exists(val_dir):
        os.makedirs(val_dir)
    for file in train_files:
        shutil.copy(os.path.join(label_dir, file), os.path.join(train_dir, file))
    for file in val_files:
        shutil.copy(os.path.join(label_dir, file), os.path.join(val_dir, file))

# finally, rename the labels folders to 0, 1, 2 ... N instead of 2, 4, 6, 8
files = sorted(os.listdir("data/train/"))
print(files)
for file in files:
    shutil.move("data/train/" + file, "data/train/" + str(files.index(file)))
    shutil.move("data/val/" + file, "data/val/" + str(files.index(file)))
    

Using a total of 192 samples to process
Using a total of 192 samples to process
Using a total of 192 samples to process
Using a total of 192 samples to process
['0', '2', '4', '6']


## Step 3: Train of the StarGan

In [None]:
%conda install -n stargan2 ipykernel --update-deps --force-reinstall

: 

In [1]:
import os
from munch import Munch
from torch.backends import cudnn
import torch
from core.data_loader import get_train_loader, get_test_loader
from core.solver import Solver
from core.wing import align_faces
# import importlib

# importlib.reload(get_train_loader)



def subdirs(dname):
    return [d for d in os.listdir(dname) if os.path.isdir(os.path.join(dname, d))]

class Args:
    img_size = 1024
    num_domains = 4
    latent_dim = 16
    hidden_dim = 512
    style_dim = 64
    lambda_reg = 1
    lambda_cyc = 1
    lambda_sty = 1
    lambda_ds = 1
    ds_iter = 5000
    w_hpf = 0 # For SSVEP
    randcrop_prob = 0.5
    total_iters = 5000
    resume_iter = 0
    batch_size = 8
    val_batch_size = 8
    lr = 5e-4
    f_lr = 1e-6
    beta1 = 0.0
    beta2 = 0.99
    weight_decay = 1e-4
    num_outs_per_domain = 10
    mode = 'train'
    num_workers = 16
    seed = 777
    train_img_dir = 'data/train'
    val_img_dir = 'data/val'
    sample_dir = 'expr/samples'
    checkpoint_dir = 'expr/checkpoints'
    eval_dir = 'expr/eval'
    result_dir = 'expr/results'
    # src_dir = 'assets/representative/celeba_hq/src'
    # ref_dir = 'assets/representative/celeba_hq/ref'
    # inp_dir = 'assets/representative/custom/female'
    # out_dir = 'assets/representative/celeba_hq/src/female'
    # wing_path = 'expr/checkpoints/wing.ckpt'
    # lm_path = 'expr/checkpoints/celeba_lm_mean.npz'
    print_every = 10
    sample_every = 250
    save_every = 1000
    eval_every = 2500
    runType = 'w_4s'

args = Args()

# print(args)
cudnn.benchmark = True
torch.manual_seed(args.seed)

solver = Solver(args)

assert len(subdirs(args.train_img_dir)) == args.num_domains
assert len(subdirs(args.val_img_dir)) == args.num_domains
loaders = Munch(src=get_train_loader(root=args.train_img_dir,
                                        which='source',
                                        img_size=args.img_size,
                                        batch_size=args.batch_size,
                                        prob=args.randcrop_prob,
                                        num_workers=args.num_workers),
                ref=get_train_loader(root=args.train_img_dir,
                                        which='reference',
                                        img_size=args.img_size,
                                        batch_size=args.batch_size,
                                        prob=args.randcrop_prob,
                                        num_workers=args.num_workers),
                val=get_test_loader(root=args.val_img_dir,
                                    img_size=args.img_size,
                                    batch_size=args.val_batch_size,
                                    shuffle=True,
                                    num_workers=args.num_workers))
solver.train(loaders)

Note: you may need to restart the kernel to use updated packages.
Training on cuda
Number of parameters of generator: 12112129
Number of parameters of mapping_network: 4079872
Number of parameters of style_encoder: 7008622
Number of parameters of discriminator: 6879346
Initializing generator...
Initializing mapping_network...
Initializing style_encoder...
Initializing discriminator...
Preparing DataLoader to fetch source images during the training phase...
Preparing DataLoader to fetch reference images during the training phase...
Preparing DataLoader for the generation phase...
Start training...
Elapsed time [0:00:08], Iteration [10/5000], D/latent_real: [0.0004] D/latent_fake: [0.0177] D/latent_reg: [0.0010] D/ref_real: [0.0005] D/ref_fake: [0.0062] D/ref_reg: [0.0010] G/latent_adv: [9.2899] G/latent_sty: [0.9560] G/latent_ds: [0.4710] G/latent_cyc: [4.5077] G/ref_adv: [6.4575] G/ref_sty: [0.0690] G/ref_ds: [0.0368] G/ref_cyc: [4.5105] G/lambda_ds: [0.9980]
Elapsed time [0:00:10], It

AttributeError: 'list' object has no attribute 'cpu'