# Dataset example (LFCC)

In [None]:
from datasets.dataset import LADataset, collate_fn
from torch.utils.data import DataLoader
from spafe.utils.vis import show_features

# Locations of the ASVspoof2019 LA training split: the CM protocol file
# listing the utterances, and the directory holding the audio.
txtpath = "datasets/LA/ASVspoof2019_LA_cm_protocols/ASVspoof2019.LA.cm.train.trn.txt"
datadir = "datasets/LA/ASVspoof2019_LA_train/"

# Build the training set with an LFCC front end (512-point FFT, 20 coefficients).
train_dataset = LADataset(
    split="train",
    transforms="lfcc",
    n_fft=512,
    num_features=20,
    txtpath=txtpath,
    datadir=datadir,
)

# Plot the feature matrix stored as element 0 of the first dataset item.
show_features(
    train_dataset[0][0].numpy(),
    "Linear Frequency Cepstral Coefficients",
    "LFCC Index",
    "Frame Index",
)

# Simple training framework

In [None]:
from models.resnet import ResNet
import torch
import torch.nn as nn
import os
import torch.nn.functional as F
from tqdm import tqdm

# Define device
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Current device is : {device}")

## Trainer function

In [None]:
from utils.train import trainer

## Training config

In [None]:
# Classifier: ResNet variant '18' with a 2-class output, moved to `device`.
# NOTE(review): the meaning of the positional args (3, 256) is defined in
# models.resnet — confirm they match the LFCC feature layout produced below.
model = ResNet(3, 256, '18', nclasses=2).to(device)

# Feature-extraction settings shared by BOTH splits, so train and dev
# features are guaranteed to be computed identically.
lfcc_kwargs = dict(transforms="lfcc", n_fft=512, num_features=20)

# Training split: CM protocol file and audio directory
train_txtpath = "datasets/LA/ASVspoof2019_LA_cm_protocols/ASVspoof2019.LA.cm.train.trn.txt"
train_datadir = "datasets/LA/ASVspoof2019_LA_train/"

# Validation (dev) split: CM protocol file and audio directory
val_txtpath = "datasets/LA/ASVspoof2019_LA_cm_protocols/ASVspoof2019.LA.cm.dev.trl.txt"
val_datadir = "datasets/LA/ASVspoof2019_LA_dev/"

train_dataset = LADataset(split="train", txtpath=train_txtpath, datadir=train_datadir, **lfcc_kwargs)
val_dataset = LADataset(split="dev", txtpath=val_txtpath, datadir=val_datadir, **lfcc_kwargs)

# Run-length hyperparameters
epochs = 10
batch_size = 64

# Adam with L2 weight decay
learning_rate = 0.0003
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, betas=(0.9, 0.999), eps=1e-8, weight_decay=0.0005)

# Halve the learning rate every 10 scheduler steps.
# NOTE(review): step_size (10) equals `epochs` (10); if utils.train.trainer
# steps the scheduler once per epoch, the decay never affects this run —
# confirm this is intended.
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, gamma=0.5, step_size=10)

In [None]:
# Launch the "baseline" experiment with the configuration built above.
# loss_opt='ce' is presumably cross-entropy — see utils.train.trainer to confirm.
trainer(
    model,
    train_dataset,
    val_dataset,
    optimizer,
    scheduler,
    epochs,
    loss_opt='ce',
    batch_size=batch_size,
    exp_name="baseline",
    device=device,
)