# Setup

In [3]:
import os
# Configure PyTorch's CUDA caching allocator through its environment variable.
# This cell runs ahead of the notebook's later `import torch` cell.
# NOTE(review): `max_split_size_mb:512` is presumably meant to limit block
# splitting / fragmentation for this model size -- confirm it is still needed.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:512"

In [4]:
from google.colab import drive
# Mount Google Drive under /content/drive; every path used later in this
# notebook (requirements, datafiles, source checkouts, exp_dir) lives below it.
drive.mount('/content/drive', force_remount=True)

Mounted at /content/drive


In [5]:
%cd /content/drive/MyDrive/Thesis/

/content/drive/MyDrive/Thesis


In [6]:
!pip install -q -r requirements.txt

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m474.6/474.6 kB[0m [31m21.5 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m42.2/42.2 kB[0m [31m5.0 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m287.4/287.4 kB[0m [31m32.2 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m110.5/110.5 kB[0m [31m14.0 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m212.5/212.5 kB[0m [31m22.0 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m134.3/134.3 kB[0m [31m15.5 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.0/1.0 MB[0m [31m69.1 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m236.8/236.8 kB[0m [31m26.6 MB/s[

In [7]:
%cd /content/drive/MyDrive/Thesis/ast

/content/drive/MyDrive/Thesis/ast


In [8]:
%load_ext autoreload
%autoreload 2
%matplotlib inline
%reload_ext autoreload


In [9]:
import torch
import torch.nn.functional as F
import torch.nn as nn
from torch.cuda.amp import autocast
from einops.layers.torch import Rearrange

from src import dataloader
from src import models
from src.traintest import train, validate


In [10]:
class Arguments:
    """Plain namespace of experiment settings used throughout this notebook.

    All values are class attributes. The notebook creates one instance
    (``args``) and hands it to the data-loader setup, the model constructor
    and ``train``. How ``train`` interprets the individual fields is defined
    in ``src.traintest``, which is not part of this notebook -- confirm the
    exact semantics there before changing a value.
    """

    # --- model / dataset selection -------------------------------------
    model = 'ast'
    dataset = 'speechcommands'
    # NOTE(review): the ASTModel defined in this notebook accepts the two
    # pretrain flags but never reads them, so no pretrained weights are loaded.
    imagenetpretrain = True
    audiosetpretrain = False

    # --- optimisation basics -------------------------------------------
    bal = None
    lr = 2.5e-4

    # --- schedule length, augmentation, input statistics ---------------
    n_epochs = 20
    freqm = 48                  # frequency-mask size passed to the training dataset
    timem = 48                  # time-mask size passed to the training dataset
    mixup = 0.6                 # mix-up rate passed to the training dataset
    batch_size = 128            # training batch size (validation/eval use twice this)
    fstride = 10                # patch stride along the frequency axis
    tstride = 10                # patch stride along the time axis
    dataset_mean = -6.845978    # input normalisation mean
    dataset_std = 5.5654526     # input normalisation std
    audio_length = 128          # 'target_length' of the spectrogram / model input_tdim
    noise = True                # noise augmentation switch for the training dataset

    # --- runtime / bookkeeping ------------------------------------------
    num_workers = 32
    exp_dir = '/content/drive/MyDrive/Thesis/resout_not_pretrained'
    optimizer = 'adam'
    metrics = 'acc'
    loss = 'BCE'

    # --- learning-rate scheduler ----------------------------------------
    lrscheduler_start = 5       # first epoch at which the decay is applied
    lrscheduler_step = 1        # epochs between decays
    lrscheduler_decay = 0.85    # multiplicative decay factor

    # --- warm-up / weight averaging (consumed by the training loop) -----
    warmup = False
    wa = False
    wa_start = 1
    wa_end = 5

    # --- logging / misc --------------------------------------------------
    n_print_steps = 100
    n_class = 35
    lr_patience = 2
    save_model = True


args = Arguments()

In [11]:
%cd /content/drive/MyDrive/Thesis/Quaternion_Transformer_Pytorch/

/content/drive/MyDrive/Thesis/Quaternion_Transformer_Pytorch


In [12]:
from core_qnn.quaternion_ops import *
from core_qnn.quaternion_layers import *
from timm.models.layers import to_2tuple,trunc_normal_

In [13]:
# Notebook-wide compute device: CUDA when PyTorch reports it as available, otherwise CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Transformer part

In [14]:
def QNorm(x, eps):
    """Divide each quaternion in ``x`` by its modulus.

    The last dimension of ``x`` is split into four equal, contiguous chunks
    that are read as the r, i, j and k components. The modulus is
    ``sqrt(r^2 + i^2 + j^2 + k^2 + eps)``, computed element-wise; ``eps`` is
    added under the square root to keep the division finite for zero input.

    Args:
        x: tensor whose last dimension is divisible by 4.
        eps: small constant added inside the square root.

    Returns:
        A list ``[r, i, j, k]`` of the four normalised component tensors,
        each with a quarter of the original last dimension.
    """
    components = torch.chunk(x, chunks=4, dim=-1)
    r, i, j, k = components
    modulus = torch.sqrt(r * r + i * i + j * j + k * k + eps)
    return [part / modulus for part in components]


class Norm(nn.Module):
    """Quaternion normalisation layer with a learnable affine calibration.

    The input is normalised to unit quaternion modulus by ``QNorm``; the same
    learnable gain ``alpha`` and offset ``bias`` (both of length
    ``d_model // 4``) are then applied to each of the four components.

    Args:
        d_model: size of the last input dimension (four stacked components).
        eps: stabiliser forwarded to ``QNorm``.
    """

    def __init__(self, d_model, eps=1e-6):
        super().__init__()

        # One gain/offset per quaternion, shared by its r, i, j and k parts.
        self.size = d_model // 4
        self.alpha = nn.Parameter(torch.ones(self.size))
        self.bias = nn.Parameter(torch.zeros(self.size))
        self.eps = eps

    def forward(self, x):
        """Return the normalised, re-scaled tensor with the same shape as ``x``."""
        calibrated = [
            self.alpha * part + self.bias
            for part in QNorm(x, self.eps)
        ]
        return torch.cat(calibrated, dim=-1)

In [15]:
def quarternion_multiplication(a, b, transpose=True):
    """Hamilton product between two quaternion-valued tensors, using matmul.

    The last dimension of ``a`` and ``b`` is split into four equal chunks read
    as (r, x, y, z) and (r', x', y', z'). Every scalar multiplication of the
    Hamilton product is replaced by a matrix product:

        (rr' - xx' - yy' - zz')  +
        (rx' + xr' + yz' - zy')i +
        (ry' - xz' + yr' + zx')j +
        (rz' + xy' - yx' + zr')k

    Args:
        a: left operand, shape ``(..., N, 4 * d)``.
        b: right operand. With ``transpose`` set its chunks are transposed over
            the last two dimensions before the matmul (so ``b`` is
            ``(..., M, 4 * d)``); otherwise its chunks are used as they are.
        transpose: whether to transpose the chunks of ``b``.

    Returns:
        The four product components concatenated along the last dimension,
        e.g. ``(..., N, 4 * M)`` when ``transpose`` is set.
    """
    ar, ax, ay, az = torch.chunk(a, chunks=4, dim=-1)
    br, bx, by, bz = torch.chunk(b, chunks=4, dim=-1)

    if transpose:
        # transpose(-2, -1) is identical to .t() on 2-D tensors, so a single
        # code path serves both the batched and the plain matrix case.
        br, bx, by, bz = (part.transpose(-2, -1) for part in (br, bx, by, bz))

    mm = torch.matmul
    r = mm(ar, br) - mm(ax, bx) - mm(ay, by) - mm(az, bz)
    i = mm(ar, bx) + mm(ax, br) + mm(ay, bz) - mm(az, by)
    j = mm(ar, by) - mm(ax, bz) + mm(ay, br) + mm(az, bx)
    k = mm(ar, bz) + mm(ax, by) - mm(ay, bx) + mm(az, br)

    return torch.cat([r, i, j, k], dim=-1)

In [16]:
def ComponentActivation(q, act_func=F.gelu):
    """Apply an activation separately to the four quaternion components.

    ``q`` is split into four equal chunks along its last dimension, the
    activation is evaluated on every chunk independently, and the results are
    concatenated back in the same order.

    Args:
        q: tensor whose last dimension is divisible by 4.
        act_func: callable applied to each chunk. ``F.softmax`` is special-cased
            and called with ``dim=-1`` so that it normalises within a chunk.

    Returns:
        Tensor with the same shape as ``q``.
    """
    # Only softmax needs an explicit reduction axis.
    extra_kwargs = {'dim': -1} if act_func == F.softmax else {}
    activated = [
        act_func(component, **extra_kwargs)
        for component in torch.chunk(q, 4, dim=-1)
    ]
    return torch.cat(activated, dim=-1)


In [17]:
#TODO not sure about scale applied to q only
class Attention(nn.Module):
    """Multi-head self-attention assembled from quaternion building blocks.

    Queries, keys and values come from one ``QuaternionLinearAutograd``
    projection. Attention scores are the quaternion product of ``q`` and
    ``k``; a softmax is applied to each score component separately, and the
    weighted values are again formed with a quaternion product.

    Args:
        dim: embedding size; must be divisible by ``num_heads``.
        num_heads: number of attention heads.
        qkv_bias: passed as ``bias`` to the q/k/v projection.
        qk_norm: when true, ``Norm`` is applied per head to q and k.
        attn_drop: dropout probability on the attention weights.
        proj_drop: dropout probability after the output projection.
        norm_layer: accepted for signature compatibility; not used here.
    """

    def __init__(
            self,
            dim,  # embed_size
            num_heads=8,
            qkv_bias=False,
            qk_norm=False,
            attn_drop=0.,
            proj_drop=0.,
            norm_layer=nn.LayerNorm,
    ):
        super().__init__()
        assert dim % num_heads == 0, 'dim should be divisible by num_heads'
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        # Scaling factor 1 / sqrt(head_dim); applied to q only (see forward).
        self.scale = self.head_dim ** -0.5

        self.qkv = QuaternionLinearAutograd(dim, dim * 3, bias=qkv_bias)
        if qk_norm:
            self.q_norm = Norm(self.head_dim)
            self.k_norm = Norm(self.head_dim)
        else:
            self.q_norm = nn.Identity()
            self.k_norm = nn.Identity()
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = QuaternionLinearAutograd(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        """Attend over ``x`` of shape (B, N, C) and return a (B, N, C) tensor."""
        batch, tokens, channels = x.shape

        # (B, N, 3*C) -> (3, B, heads, N, head_dim), then split into q, k, v.
        # NOTE(review): this slices the projection output into contiguous
        # q/k/v thirds. If QuaternionLinearAutograd lays its output out as
        # [r | i | j | k] quarters, the thirds do not align with quaternion
        # components -- confirm against core_qnn.
        projected = self.qkv(x)
        projected = projected.reshape(batch, tokens, 3, self.num_heads, self.head_dim)
        q, k, v = projected.permute(2, 0, 3, 1, 4).unbind(0)

        q = self.q_norm(q) * self.scale
        k = self.k_norm(k)

        # Quaternion scores: four (N, N) maps concatenated on the last axis.
        scores = quarternion_multiplication(q, k)
        weights = ComponentActivation(scores, act_func=F.softmax)
        weights = self.attn_drop(weights)

        # Weighted values, then merge the heads back into the channel axis.
        context = quarternion_multiplication(weights, v, transpose=False)
        context = context.transpose(1, 2).reshape(batch, tokens, channels)

        return self.proj_drop(self.proj(context))

In [18]:
class Mlp(nn.Module):
    """Two-layer feed-forward block built from quaternion linear layers.

    Structure: quaternion linear -> component-wise GELU -> dropout ->
    quaternion linear -> dropout.

    Args:
        in_features: input width.
        hidden_features: hidden width; falls back to ``in_features`` when falsy.
        out_features: output width; falls back to ``in_features`` when falsy.
        drop: dropout probability used after both linear layers.
    """

    def __init__(self, in_features, hidden_features=None, out_features=None, drop=0.):
        super().__init__()
        if not out_features:
            out_features = in_features
        if not hidden_features:
            hidden_features = in_features

        self.fc1 = QuaternionLinearAutograd(in_features, hidden_features)
        self.drop1 = nn.Dropout(drop)
        self.fc2 = QuaternionLinearAutograd(hidden_features, out_features)
        self.drop2 = nn.Dropout(drop)

    def forward(self, x):
        """Run the feed-forward block on ``x``."""
        hidden = ComponentActivation(self.fc1(x), act_func=F.gelu)
        hidden = self.drop1(hidden)
        return self.drop2(self.fc2(hidden))

In [19]:
class Block(nn.Module):
    """Pre-norm transformer encoder block using the quaternion sub-layers.

    Computes ``x + attn(norm1(x))`` followed by ``x + mlp(norm2(x))``.

    Args:
        dim: embedding size.
        num_heads: number of attention heads.
        mlp_ratio: hidden width of the MLP as a multiple of ``dim``.
        qkv_bias, qk_norm, attn_drop: forwarded to ``Attention``.
        drop: dropout probability for the attention projection and the MLP.
        init_values, drop_path, act_layer: accepted for signature
            compatibility; layer-scale, stochastic depth and a configurable
            activation are not implemented in this block.
        norm_layer: constructor for both normalisation layers.
    """

    def __init__(
            self,
            dim,
            num_heads,
            mlp_ratio=4.,
            qkv_bias=False,
            qk_norm=False,
            drop=0.,
            attn_drop=0.,
            init_values=None,
            drop_path=0.,
            act_layer=nn.GELU,
            norm_layer=Norm
    ):
        super().__init__()

        # Attention branch.
        self.norm1 = norm_layer(dim)
        attention_kwargs = dict(
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_norm=qk_norm,
            attn_drop=attn_drop,
            proj_drop=drop,
            norm_layer=norm_layer,
        )
        self.attn = Attention(dim, **attention_kwargs)

        # Feed-forward branch.
        self.norm2 = norm_layer(dim)
        mlp_hidden = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden, drop=drop)

    def forward(self, x):
        """Apply the two residual branches in sequence."""
        attended = self.attn(self.norm1(x))
        x = x + attended
        transformed = self.mlp(self.norm2(x))
        x = x + transformed
        return x

# QModel

In [27]:
# # Patchify using Rearrange (cut without a CNN); returns quaternion patches projected by Qproj
# class PatchifierEmbed(nn.Module):
#     def __init__(self, patch_size=16, embedding_dim = 768):
#         super().__init__()
#         # [B, H, W] with C = 1 omited ([B ,C, H, W])
#         # [B, Number_of_patches, Patch_size=patch_size*patch_size] 
#         # for Q version [B, Number_of_patches, 4*Patch_size=4*patch_size*patch_size] 
#         self.patchifier = Rearrange('b (h p1) (w p2) -> b (h w) (p1 p2)', p1=patch_size, p2=patch_size)
#         # self.proj = nn.Linear(patch_size*patch_size*n_channels, embedding_dim)
#         self.Qproj = QuaternionLinearAutograd(
#             patch_size*patch_size*4, embedding_dim
#         )

#     def forward(self, x):
#         _, n_channels, _, _ = x.shape
#         y = []
#         for channel in range(n_channels):
#           y.append(self.patchifier(x[:, channel,:,:]))
#         zeros = torch.zeros(y[0].shape).to(device)
#         if n_channels==1:
#           # grey => r
#           # zero/black => i,j,k 
#           out = torch.cat((y[0], zeros, zeros, zeros), 2)
#         else:
#           # zero => r
#           # r,g,b => i,j,k 
#           out = torch.cat((zeros, y[0], y[1], y[2]), 2)
#         out = self.Qproj(out)
#         return out

# Patch embedding using a conventional CNN (strided convolution), as in the original paper
# override the timm package to relax the input shape constraint.
# 
class PatchEmbed(nn.Module):
    """Turn a single-channel spectrogram into a sequence of patch embeddings.

    The input is lifted to a four-channel quaternion image ``[0, x, x, x]``
    (zero real part, the input repeated as i, j and k) and passed through a
    strided ``QuaternionConv``; the resulting feature map is flattened into a
    token sequence.

    Args:
        img_size: stored as a 2-tuple on ``self.img_size``; not otherwise used.
        patch_size: convolution kernel size (int or pair).
        stride: convolution stride, forwarded unchanged to ``QuaternionConv``.
        in_chans: channel count given to ``QuaternionConv``; ``forward`` always
            produces 4 channels from a 1-channel input.
        embed_dim: output channel count given to ``QuaternionConv``.
    """

    def __init__(self, img_size=256, patch_size=16, stride = 10, in_chans=4, embed_dim=768):
        super().__init__()

        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)
        self.img_size = img_size
        self.patch_size = patch_size

        # TODO compare qconv or conv works for patchembedding
        self.projq = QuaternionConv(in_chans, embed_dim, patch_size, stride)

    def forward(self, x):
        """Embed ``x`` of shape (B, 1, H, W) into (B, num_patches, channels)."""
        # zeros_like keeps the padding on x's own device and dtype instead of
        # relying on a notebook-global `device`, which breaks as soon as the
        # module and its input live somewhere else (e.g. CPU with CUDA present).
        zeros = torch.zeros_like(x)
        x = torch.cat((zeros, x, x, x), 1)
        # (B, C, H', W') -> (B, C, H'*W') -> (B, H'*W', C)
        x = self.projq(x).flatten(2).transpose(1, 2)
        return x

In [28]:
class ASTModel(nn.Module):
  """Audio Spectrogram Transformer assembled from the quaternion blocks above.

  Pipeline: quaternion patch embedding -> prepend class and distillation
  tokens -> add learned positional embedding -> 12 transformer blocks ->
  quaternion norm -> average of the two special tokens -> norm -> linear head.

  Args:
    label_dim: number of output classes; ``<= 0`` replaces the head by identity.
    fstride: patch stride along the frequency axis.
    tstride: patch stride along the time axis.
    input_fdim: number of frequency bins of the input spectrogram.
    input_tdim: number of time frames of the input spectrogram.
    imagenet_pretrain, audioset_pretrain, model_size, verbose: accepted for
      signature compatibility only. This implementation never reads them, so
      no pretrained weights are loaded whatever their values.
  """

  def __init__(self, label_dim=527, fstride=10, tstride=10, input_fdim=128, input_tdim=1024, imagenet_pretrain=True, audioset_pretrain=False, model_size='base384', verbose=True):
    super(ASTModel, self).__init__()
    self.original_embedding_dim = 768
    num_heads = 12
    mlp_ratio = 4.
    qkv_bias = True
    qk_norm = False
    drop_rate = 0.
    attn_drop_rate = 0.
    depth = 12

    # Number of patches the embedding will produce for this input size.
    f_dim, t_dim = self.get_shape(fstride, tstride, input_fdim, input_tdim)
    num_patches = f_dim * t_dim

    # The patch embedding must use the same strides that sized pos_embed;
    # otherwise any non-default fstride/tstride yields a token count that does
    # not match the positional embedding. Equal strides are passed as a plain
    # int, exactly as before.
    # NOTE(review): unequal strides are passed as a (freq, time) pair --
    # confirm QuaternionConv accepts a tuple stride.
    patch_stride = fstride if fstride == tstride else (fstride, tstride)
    self.patch_embed = PatchEmbed(stride=patch_stride, embed_dim=self.original_embedding_dim)
    self.cls_token = nn.Parameter(torch.zeros(1, 1, self.original_embedding_dim))
    self.dist_token =  nn.Parameter(torch.zeros(1, 1, self.original_embedding_dim))
    # TODO pretrained or sinusoidal
    # +2 accounts for the class and distillation tokens.
    self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 2, self.original_embedding_dim))

    trunc_normal_(self.pos_embed, std=.02)
    self.pos_drop = nn.Dropout(p=0.)

    self.blocks = nn.Sequential(*[
            Block(
                self.original_embedding_dim,
                num_heads,
                mlp_ratio = mlp_ratio,
                qkv_bias = qkv_bias,
                qk_norm = qk_norm,
                drop = drop_rate,
                attn_drop = attn_drop_rate,
            )
            for _ in range(depth)])
    self.norm =  Norm(self.original_embedding_dim)

    # Classifier Head
    self.fc_norm = Norm(self.original_embedding_dim)
    self.head = nn.Linear(self.original_embedding_dim, label_dim) if label_dim > 0 else nn.Identity()

  @autocast()
  def forward(self, x):
    """Classify a batch of spectrograms.

    Args:
      x: 3-D tensor; a channel axis is inserted and the last two axes are
        swapped before patching (presumably (batch, time, frequency) on
        input -- confirm against the data loader).

    Returns:
      Tensor of shape (batch, label_dim) with the head outputs.
    """
    x = x.unsqueeze(1)
    x = x.transpose(2, 3)
    B = x.shape[0]

    x = self.patch_embed(x)

    cls_tokens = self.cls_token.expand(B, -1, -1)
    dist_token = self.dist_token.expand(B, -1, -1)

    x = torch.cat((cls_tokens, dist_token, x), dim=1)
    x = x + self.pos_embed
    x = self.pos_drop(x)

    x = self.blocks(x)
    x = self.norm(x)

    # Pool by averaging the class token and the distillation token.
    x = (x[:, 0] + x[:, 1]) / 2
    x = self.fc_norm(x)
    x = self.head(x)
    return x

  def get_shape(self, fstride, tstride, input_fdim=128, input_tdim=1024):
    """Return ``(f_dim, t_dim)``, the patch-grid size for the given input.

    The size is measured by pushing a random dummy input through a throw-away
    16x16 convolution with the requested strides.
    """
    test_input = torch.randn(1, 1, input_fdim, input_tdim)
    test_proj = nn.Conv2d(1, self.original_embedding_dim, kernel_size=(16, 16), stride=(fstride, tstride))
    test_out = test_proj(test_input)
    f_dim = test_out.shape[2]
    t_dim = test_out.shape[3]
    return f_dim, t_dim


# QAST: Parameters as in original paper, no pretraining

In [None]:
# Dataset manifests (JSON) and the class-label index file on the mounted Drive.
data_train = '/content/drive/MyDrive/Thesis/datafiles/speechcommand_train_data.json'
data_val_path = '/content/drive/MyDrive/Thesis/datafiles/speechcommand_valid_data.json'
data_eval_path = '/content/drive/MyDrive/Thesis/datafiles/speechcommand_eval_data.json'

label_csv = '/content/drive/MyDrive/Thesis/ast/egs/speechcommands/data/speechcommands_class_labels_indices.csv'


def _make_audio_conf(mode, freqm=0, timem=0, mixup=0, noise=False):
    """Build the dataset configuration dict; augmentation is off by default."""
    return {'num_mel_bins': 128, 'target_length': args.audio_length,
            'freqm': freqm, 'timem': timem, 'mixup': mixup,
            'dataset': args.dataset, 'mode': mode,
            'mean': args.dataset_mean, 'std': args.dataset_std,
            'noise': noise}


def _make_loader(json_path, conf, batch_size, shuffle):
    """Wrap one dataset manifest in a DataLoader with the shared worker settings."""
    dataset = dataloader.AudiosetDataset(json_path, label_csv=label_csv, audio_conf=conf)
    return torch.utils.data.DataLoader(
        dataset, batch_size=batch_size, shuffle=shuffle,
        num_workers=args.num_workers, pin_memory=True)


# Training uses masking, mix-up and noise; validation/evaluation use none.
audio_conf = _make_audio_conf('train', freqm=args.freqm, timem=args.timem,
                              mixup=args.mixup, noise=args.noise)
val_audio_conf = _make_audio_conf('validation')
eval_audio_conf = _make_audio_conf('evaluation')

train_loader = _make_loader(data_train, audio_conf, args.batch_size, shuffle=True)

# The evaluation split now uses its own configuration. Previously it was built
# with val_audio_conf, leaving eval_audio_conf unused and labelling the
# evaluation dataset as 'validation'.
eval_loader = _make_loader(data_eval_path, eval_audio_conf, args.batch_size*2, shuffle=False)

val_loader = _make_loader(data_val_path, val_audio_conf, args.batch_size*2, shuffle=False)

---------------the train dataloader---------------
now using following mask: 48 freq, 48 time
now using mix-up with rate 0.600000
now process speechcommands
use dataset mean -6.846 and std 5.565 to normalize the input.
now use noise augmentation
number of classes is 35
---------------the validation dataloader---------------
now using following mask: 0 freq, 0 time
now using mix-up with rate 0.000000
now process speechcommands
use dataset mean -6.846 and std 5.565 to normalize the input.
number of classes is 35
---------------the validation dataloader---------------
now using following mask: 0 freq, 0 time
now using mix-up with rate 0.000000
now process speechcommands
use dataset mean -6.846 and std 5.565 to normalize the input.
number of classes is 35




In [None]:
# Build the quaternion AST for the configured class count, strides and input length.
# NOTE(review): imagenet_pretrain / audioset_pretrain / model_size are passed
# through, but the ASTModel defined in this notebook does not read them, so
# the network is trained from scratch.
ast_mdl = ASTModel(label_dim=args.n_class, fstride=args.fstride, tstride=args.tstride, input_fdim=128,
                                input_tdim=args.audio_length, imagenet_pretrain=args.imagenetpretrain,
                                audioset_pretrain=args.audiosetpretrain, model_size='base384')

In [None]:
# Move the model to the selected device; as the cell's last expression, the
# returned module is also displayed (its layer-by-layer repr).
ast_mdl.to(device)

ASTModel(
  (patch_embed): PatchEmbed(
    (projq): QuaternionConv(in_channels=1, out_channels=192, bias=True, kernel_size=(16, 16), stride=10, padding=0, init_criterion=glorot, weight_init=quaternion, seed=1158, rotation=False, q_format=True, operation=convolution2d)
  )
  (pos_drop): Dropout(p=0.0, inplace=False)
  (blocks): Sequential(
    (0): Block(
      (norm1): Norm()
      (attn): Attention(
        (qkv): QuaternionLinearAutograd(in_features=192, out_features=576, bias=True, init_criterion=glorot, weight_init=quaternion, rotation=False, seed=694)
        (q_norm): Identity()
        (k_norm): Identity()
        (attn_drop): Dropout(p=0.0, inplace=False)
        (proj): QuaternionLinearAutograd(in_features=192, out_features=192, bias=True, init_criterion=glorot, weight_init=quaternion, rotation=False, seed=764)
        (proj_drop): Dropout(p=0.0, inplace=False)
      )
      (norm2): Norm()
      (mlp): Mlp(
        (fc1): QuaternionLinearAutograd(in_features=192, out_features

In [None]:
# Announce the planned epoch count, then hand the model, the training and
# validation loaders and the settings object to the project's training loop.
print(f'Now starting training for {args.n_epochs:d} epochs')

train(ast_mdl, train_loader, val_loader, args)

Now starting training for 20 epochs
running on cuda
Total parameter number is : 21.665 million
Total trainable parameter number is : 21.665 million
now training with speechcommands, main metrics: acc, loss function: BCEWithLogitsLoss(), learning rate scheduler: <torch.optim.lr_scheduler.MultiStepLR object at 0x7f294d682c10>
The learning rate scheduler starts at 5 epoch with decay rate of 0.850 every 1 epochs
current #steps=0, #epochs=1
start training...
---------------
2023-04-14 13:27:12.379812
current #epochs=1, #steps=0
Epoch: [1][100/662]	Per Sample Total Time 0.00529	Per Sample Data Time 0.00042	Per Sample DNN Time 0.00487	Train Loss 0.1286	
Epoch: [1][200/662]	Per Sample Total Time 0.00508	Per Sample Data Time 0.00021	Per Sample DNN Time 0.00487	Train Loss 0.1285	
Epoch: [1][300/662]	Per Sample Total Time 0.00500	Per Sample Data Time 0.00014	Per Sample DNN Time 0.00486	Train Loss 0.1283	
Epoch: [1][400/662]	Per Sample Total Time 0.00497	Per Sample Data Time 0.00011	Per Sample DNN



Epoch: [2][38/662]	Per Sample Total Time 0.00593	Per Sample Data Time 0.00105	Per Sample DNN Time 0.00488	Train Loss 0.1269	
Epoch: [2][138/662]	Per Sample Total Time 0.00516	Per Sample Data Time 0.00030	Per Sample DNN Time 0.00486	Train Loss 0.1267	
Epoch: [2][238/662]	Per Sample Total Time 0.00503	Per Sample Data Time 0.00017	Per Sample DNN Time 0.00486	Train Loss 0.1266	
Epoch: [2][338/662]	Per Sample Total Time 0.00498	Per Sample Data Time 0.00012	Per Sample DNN Time 0.00486	Train Loss 0.1265	
Epoch: [2][438/662]	Per Sample Total Time 0.00495	Per Sample Data Time 0.00010	Per Sample DNN Time 0.00486	Train Loss 0.1263	
Epoch: [2][538/662]	Per Sample Total Time 0.00494	Per Sample Data Time 0.00008	Per Sample DNN Time 0.00486	Train Loss 0.1260	
Epoch: [2][638/662]	Per Sample Total Time 0.00492	Per Sample Data Time 0.00007	Per Sample DNN Time 0.00486	Train Loss 0.1257	
start validation
acc: 0.166015
AUC: 0.762052
Avg Precision: 0.081434
Avg Recall: 0.675836
d_prime: 1.008222
train_loss:



Epoch: [3][76/662]	Per Sample Total Time 0.00543	Per Sample Data Time 0.00056	Per Sample DNN Time 0.00487	Train Loss 0.1223	
Epoch: [3][176/662]	Per Sample Total Time 0.00511	Per Sample Data Time 0.00025	Per Sample DNN Time 0.00486	Train Loss 0.1221	
Epoch: [3][276/662]	Per Sample Total Time 0.00502	Per Sample Data Time 0.00016	Per Sample DNN Time 0.00486	Train Loss 0.1217	
Epoch: [3][376/662]	Per Sample Total Time 0.00497	Per Sample Data Time 0.00012	Per Sample DNN Time 0.00486	Train Loss 0.1214	
Epoch: [3][476/662]	Per Sample Total Time 0.00495	Per Sample Data Time 0.00009	Per Sample DNN Time 0.00486	Train Loss 0.1209	
Epoch: [3][576/662]	Per Sample Total Time 0.00493	Per Sample Data Time 0.00008	Per Sample DNN Time 0.00486	Train Loss 0.1206	
start validation
acc: 0.267208
AUC: 0.874567
Avg Precision: 0.113632
Avg Recall: 0.729903
d_prime: 1.623866
train_loss: 0.120348
valid_loss: 0.099219
validation finished
Epoch-3 lr: 0.00025
epoch 3 training time: 445.494
---------------
2023-04-



Epoch: [4][14/662]	Per Sample Total Time 0.00767	Per Sample Data Time 0.00276	Per Sample DNN Time 0.00491	Train Loss 0.1175	
Epoch: [4][114/662]	Per Sample Total Time 0.00523	Per Sample Data Time 0.00036	Per Sample DNN Time 0.00486	Train Loss 0.1166	
Epoch: [4][214/662]	Per Sample Total Time 0.00506	Per Sample Data Time 0.00019	Per Sample DNN Time 0.00486	Train Loss 0.1162	
Epoch: [4][314/662]	Per Sample Total Time 0.00499	Per Sample Data Time 0.00013	Per Sample DNN Time 0.00486	Train Loss 0.1159	
Epoch: [4][414/662]	Per Sample Total Time 0.00496	Per Sample Data Time 0.00010	Per Sample DNN Time 0.00486	Train Loss 0.1152	
Epoch: [4][514/662]	Per Sample Total Time 0.00494	Per Sample Data Time 0.00008	Per Sample DNN Time 0.00486	Train Loss 0.1146	
Epoch: [4][614/662]	Per Sample Total Time 0.00493	Per Sample Data Time 0.00007	Per Sample DNN Time 0.00486	Train Loss 0.1141	
start validation
acc: 0.499750
AUC: 0.946366
Avg Precision: 0.109597
Avg Recall: 0.922931
d_prime: 2.277723
train_loss:



Epoch: [5][52/662]	Per Sample Total Time 0.00563	Per Sample Data Time 0.00076	Per Sample DNN Time 0.00487	Train Loss 0.1109	
Epoch: [5][152/662]	Per Sample Total Time 0.00513	Per Sample Data Time 0.00026	Per Sample DNN Time 0.00486	Train Loss 0.1095	
Epoch: [5][252/662]	Per Sample Total Time 0.00502	Per Sample Data Time 0.00016	Per Sample DNN Time 0.00486	Train Loss 0.1090	
Epoch: [5][352/662]	Per Sample Total Time 0.00497	Per Sample Data Time 0.00012	Per Sample DNN Time 0.00486	Train Loss 0.1086	
Epoch: [5][452/662]	Per Sample Total Time 0.00495	Per Sample Data Time 0.00009	Per Sample DNN Time 0.00486	Train Loss 0.1081	
Epoch: [5][552/662]	Per Sample Total Time 0.00493	Per Sample Data Time 0.00007	Per Sample DNN Time 0.00486	Train Loss 0.1077	
Epoch: [5][652/662]	Per Sample Total Time 0.00492	Per Sample Data Time 0.00006	Per Sample DNN Time 0.00486	Train Loss 0.1073	
start validation
acc: 0.627592
AUC: 0.969998
Avg Precision: 0.111862
Avg Recall: 0.966746
d_prime: 2.659792
train_loss:



Epoch: [6][90/662]	Per Sample Total Time 0.00533	Per Sample Data Time 0.00046	Per Sample DNN Time 0.00487	Train Loss 0.1036	
Epoch: [6][190/662]	Per Sample Total Time 0.00508	Per Sample Data Time 0.00022	Per Sample DNN Time 0.00486	Train Loss 0.1034	
Epoch: [6][290/662]	Per Sample Total Time 0.00500	Per Sample Data Time 0.00015	Per Sample DNN Time 0.00486	Train Loss 0.1032	
Epoch: [6][390/662]	Per Sample Total Time 0.00497	Per Sample Data Time 0.00011	Per Sample DNN Time 0.00486	Train Loss 0.1030	
Epoch: [6][490/662]	Per Sample Total Time 0.00494	Per Sample Data Time 0.00009	Per Sample DNN Time 0.00486	Train Loss 0.1027	
Epoch: [6][590/662]	Per Sample Total Time 0.00493	Per Sample Data Time 0.00007	Per Sample DNN Time 0.00486	Train Loss 0.1024	
start validation
acc: 0.688709
AUC: 0.979370
Avg Precision: 0.122149
Avg Recall: 0.977230
d_prime: 2.886283
train_loss: 0.102152
valid_loss: 0.052821
validation finished
Epoch-6 lr: 0.00018062499999999999
epoch 6 training time: 445.633
---------



Epoch: [7][28/662]	Per Sample Total Time 0.00635	Per Sample Data Time 0.00146	Per Sample DNN Time 0.00489	Train Loss 0.1004	
Epoch: [7][128/662]	Per Sample Total Time 0.00520	Per Sample Data Time 0.00033	Per Sample DNN Time 0.00487	Train Loss 0.0996	
Epoch: [7][228/662]	Per Sample Total Time 0.00505	Per Sample Data Time 0.00019	Per Sample DNN Time 0.00486	Train Loss 0.0992	
Epoch: [7][328/662]	Per Sample Total Time 0.00499	Per Sample Data Time 0.00013	Per Sample DNN Time 0.00486	Train Loss 0.0987	
Epoch: [7][428/662]	Per Sample Total Time 0.00496	Per Sample Data Time 0.00010	Per Sample DNN Time 0.00486	Train Loss 0.0986	
Epoch: [7][528/662]	Per Sample Total Time 0.00494	Per Sample Data Time 0.00008	Per Sample DNN Time 0.00486	Train Loss 0.0985	
Epoch: [7][628/662]	Per Sample Total Time 0.00493	Per Sample Data Time 0.00007	Per Sample DNN Time 0.00486	Train Loss 0.0983	
start validation
acc: 0.749224
AUC: 0.984162
Avg Precision: 0.131088
Avg Recall: 0.983804
d_prime: 3.038399
train_loss:



Epoch: [8][66/662]	Per Sample Total Time 0.00553	Per Sample Data Time 0.00064	Per Sample DNN Time 0.00489	Train Loss 0.0959	
Epoch: [8][166/662]	Per Sample Total Time 0.00513	Per Sample Data Time 0.00026	Per Sample DNN Time 0.00487	Train Loss 0.0956	
Epoch: [8][266/662]	Per Sample Total Time 0.00503	Per Sample Data Time 0.00016	Per Sample DNN Time 0.00487	Train Loss 0.0956	
Epoch: [8][366/662]	Per Sample Total Time 0.00498	Per Sample Data Time 0.00012	Per Sample DNN Time 0.00486	Train Loss 0.0954	
Epoch: [8][466/662]	Per Sample Total Time 0.00496	Per Sample Data Time 0.00009	Per Sample DNN Time 0.00486	Train Loss 0.0952	
Epoch: [8][566/662]	Per Sample Total Time 0.00494	Per Sample Data Time 0.00008	Per Sample DNN Time 0.00486	Train Loss 0.0951	
start validation
acc: 0.787496
AUC: 0.988140
Avg Precision: 0.152898
Avg Recall: 0.983553
d_prime: 3.198438
train_loss: 0.094976
valid_loss: 0.041402
validation finished
Epoch-8 lr: 0.0001305015625
epoch 8 training time: 445.998
---------------




Epoch: [9][4/662]	Per Sample Total Time 0.01290	Per Sample Data Time 0.00790	Per Sample DNN Time 0.00499	Train Loss 0.0926	
Epoch: [9][104/662]	Per Sample Total Time 0.00526	Per Sample Data Time 0.00038	Per Sample DNN Time 0.00488	Train Loss 0.0937	
Epoch: [9][204/662]	Per Sample Total Time 0.00507	Per Sample Data Time 0.00020	Per Sample DNN Time 0.00487	Train Loss 0.0936	
Epoch: [9][304/662]	Per Sample Total Time 0.00500	Per Sample Data Time 0.00013	Per Sample DNN Time 0.00487	Train Loss 0.0932	
Epoch: [9][404/662]	Per Sample Total Time 0.00496	Per Sample Data Time 0.00010	Per Sample DNN Time 0.00486	Train Loss 0.0932	
Epoch: [9][504/662]	Per Sample Total Time 0.00494	Per Sample Data Time 0.00008	Per Sample DNN Time 0.00486	Train Loss 0.0930	
Epoch: [9][604/662]	Per Sample Total Time 0.00493	Per Sample Data Time 0.00007	Per Sample DNN Time 0.00486	Train Loss 0.0929	
start validation
acc: 0.813646
AUC: 0.989919
Avg Precision: 0.142750
Avg Recall: 0.988840
d_prime: 3.285674
train_loss: 



Epoch: [10][42/662]	Per Sample Total Time 0.00580	Per Sample Data Time 0.00092	Per Sample DNN Time 0.00488	Train Loss 0.0917	
Epoch: [10][142/662]	Per Sample Total Time 0.00515	Per Sample Data Time 0.00028	Per Sample DNN Time 0.00488	Train Loss 0.0917	
Epoch: [10][242/662]	Per Sample Total Time 0.00503	Per Sample Data Time 0.00017	Per Sample DNN Time 0.00487	Train Loss 0.0912	
Epoch: [10][342/662]	Per Sample Total Time 0.00498	Per Sample Data Time 0.00012	Per Sample DNN Time 0.00486	Train Loss 0.0910	
Epoch: [10][442/662]	Per Sample Total Time 0.00495	Per Sample Data Time 0.00009	Per Sample DNN Time 0.00486	Train Loss 0.0911	
Epoch: [10][542/662]	Per Sample Total Time 0.00493	Per Sample Data Time 0.00008	Per Sample DNN Time 0.00486	Train Loss 0.0909	
Epoch: [10][642/662]	Per Sample Total Time 0.00492	Per Sample Data Time 0.00006	Per Sample DNN Time 0.00486	Train Loss 0.0908	
start validation
acc: 0.831079
AUC: 0.991295
Avg Precision: 0.151957
Avg Recall: 0.988305
d_prime: 3.362889
trai



Epoch: [11][80/662]	Per Sample Total Time 0.00536	Per Sample Data Time 0.00049	Per Sample DNN Time 0.00487	Train Loss 0.0896	
Epoch: [11][180/662]	Per Sample Total Time 0.00509	Per Sample Data Time 0.00022	Per Sample DNN Time 0.00487	Train Loss 0.0895	
Epoch: [11][280/662]	Per Sample Total Time 0.00500	Per Sample Data Time 0.00014	Per Sample DNN Time 0.00486	Train Loss 0.0895	
Epoch: [11][380/662]	Per Sample Total Time 0.00497	Per Sample Data Time 0.00011	Per Sample DNN Time 0.00486	Train Loss 0.0893	
Epoch: [11][480/662]	Per Sample Total Time 0.00494	Per Sample Data Time 0.00009	Per Sample DNN Time 0.00486	Train Loss 0.0894	
Epoch: [11][580/662]	Per Sample Total Time 0.00493	Per Sample Data Time 0.00007	Per Sample DNN Time 0.00486	Train Loss 0.0893	
start validation
acc: 0.839295
AUC: 0.992733
Avg Precision: 0.156091
Avg Recall: 0.991426
d_prime: 3.456014
train_loss: 0.089319
valid_loss: 0.032713
validation finished
Epoch-11 lr: 8.014427207031248e-05
epoch 11 training time: 444.919
--



Epoch: [12][18/662]	Per Sample Total Time 0.00707	Per Sample Data Time 0.00215	Per Sample DNN Time 0.00492	Train Loss 0.0871	
Epoch: [12][118/662]	Per Sample Total Time 0.00521	Per Sample Data Time 0.00035	Per Sample DNN Time 0.00486	Train Loss 0.0879	
Epoch: [12][218/662]	Per Sample Total Time 0.00505	Per Sample Data Time 0.00019	Per Sample DNN Time 0.00486	Train Loss 0.0881	
Epoch: [12][318/662]	Per Sample Total Time 0.00499	Per Sample Data Time 0.00013	Per Sample DNN Time 0.00486	Train Loss 0.0881	
Epoch: [12][418/662]	Per Sample Total Time 0.00496	Per Sample Data Time 0.00010	Per Sample DNN Time 0.00486	Train Loss 0.0881	
Epoch: [12][518/662]	Per Sample Total Time 0.00494	Per Sample Data Time 0.00008	Per Sample DNN Time 0.00486	Train Loss 0.0880	
Epoch: [12][618/662]	Per Sample Total Time 0.00492	Per Sample Data Time 0.00007	Per Sample DNN Time 0.00486	Train Loss 0.0880	
start validation
acc: 0.850616
AUC: 0.993441
Avg Precision: 0.161787
Avg Recall: 0.992148
d_prime: 3.508011
trai



Epoch: [13][56/662]	Per Sample Total Time 0.00558	Per Sample Data Time 0.00071	Per Sample DNN Time 0.00487	Train Loss 0.0871	
Epoch: [13][156/662]	Per Sample Total Time 0.00512	Per Sample Data Time 0.00026	Per Sample DNN Time 0.00486	Train Loss 0.0872	
Epoch: [13][256/662]	Per Sample Total Time 0.00502	Per Sample Data Time 0.00016	Per Sample DNN Time 0.00486	Train Loss 0.0873	
Epoch: [13][356/662]	Per Sample Total Time 0.00497	Per Sample Data Time 0.00012	Per Sample DNN Time 0.00486	Train Loss 0.0872	
Epoch: [13][456/662]	Per Sample Total Time 0.00495	Per Sample Data Time 0.00009	Per Sample DNN Time 0.00486	Train Loss 0.0872	
Epoch: [13][556/662]	Per Sample Total Time 0.00493	Per Sample Data Time 0.00008	Per Sample DNN Time 0.00486	Train Loss 0.0871	
Epoch: [13][656/662]	Per Sample Total Time 0.00492	Per Sample Data Time 0.00006	Per Sample DNN Time 0.00485	Train Loss 0.0871	
start validation
acc: 0.849013
AUC: 0.993584
Avg Precision: 0.169255
Avg Recall: 0.990889
d_prime: 3.519132
trai



Epoch: [14][94/662]	Per Sample Total Time 0.00533	Per Sample Data Time 0.00047	Per Sample DNN Time 0.00486	Train Loss 0.0863	
Epoch: [14][194/662]	Per Sample Total Time 0.00509	Per Sample Data Time 0.00023	Per Sample DNN Time 0.00486	Train Loss 0.0861	
Epoch: [14][294/662]	Per Sample Total Time 0.00501	Per Sample Data Time 0.00015	Per Sample DNN Time 0.00486	Train Loss 0.0860	
Epoch: [14][394/662]	Per Sample Total Time 0.00497	Per Sample Data Time 0.00011	Per Sample DNN Time 0.00486	Train Loss 0.0860	
Epoch: [14][494/662]	Per Sample Total Time 0.00495	Per Sample Data Time 0.00009	Per Sample DNN Time 0.00486	Train Loss 0.0861	
Epoch: [14][594/662]	Per Sample Total Time 0.00493	Per Sample Data Time 0.00008	Per Sample DNN Time 0.00486	Train Loss 0.0860	
start validation
acc: 0.865645
AUC: 0.994108
Avg Precision: 0.183916
Avg Recall: 0.991201
d_prime: 3.561724
train_loss: 0.086149
valid_loss: 0.028464
validation finished
Epoch-14 lr: 4.921860108518065e-05
epoch 14 training time: 445.616
--



Epoch: [15][32/662]	Per Sample Total Time 0.00619	Per Sample Data Time 0.00130	Per Sample DNN Time 0.00489	Train Loss 0.0856	
Epoch: [15][132/662]	Per Sample Total Time 0.00519	Per Sample Data Time 0.00033	Per Sample DNN Time 0.00486	Train Loss 0.0856	
Epoch: [15][232/662]	Per Sample Total Time 0.00504	Per Sample Data Time 0.00019	Per Sample DNN Time 0.00486	Train Loss 0.0853	
Epoch: [15][332/662]	Per Sample Total Time 0.00499	Per Sample Data Time 0.00013	Per Sample DNN Time 0.00486	Train Loss 0.0852	
Epoch: [15][432/662]	Per Sample Total Time 0.00496	Per Sample Data Time 0.00010	Per Sample DNN Time 0.00486	Train Loss 0.0851	
Epoch: [15][532/662]	Per Sample Total Time 0.00494	Per Sample Data Time 0.00008	Per Sample DNN Time 0.00486	Train Loss 0.0850	
Epoch: [15][632/662]	Per Sample Total Time 0.00493	Per Sample Data Time 0.00007	Per Sample DNN Time 0.00486	Train Loss 0.0851	
start validation
acc: 0.868149
AUC: 0.994447
Avg Precision: 0.171821
Avg Recall: 0.991884
d_prime: 3.591173
trai



Epoch: [16][70/662]	Per Sample Total Time 0.00545	Per Sample Data Time 0.00058	Per Sample DNN Time 0.00487	Train Loss 0.0841	
Epoch: [16][170/662]	Per Sample Total Time 0.00510	Per Sample Data Time 0.00024	Per Sample DNN Time 0.00486	Train Loss 0.0845	
Epoch: [16][270/662]	Per Sample Total Time 0.00501	Per Sample Data Time 0.00016	Per Sample DNN Time 0.00486	Train Loss 0.0847	
Epoch: [16][370/662]	Per Sample Total Time 0.00497	Per Sample Data Time 0.00011	Per Sample DNN Time 0.00486	Train Loss 0.0846	
Epoch: [16][470/662]	Per Sample Total Time 0.00495	Per Sample Data Time 0.00009	Per Sample DNN Time 0.00486	Train Loss 0.0846	
Epoch: [16][570/662]	Per Sample Total Time 0.00493	Per Sample Data Time 0.00008	Per Sample DNN Time 0.00486	Train Loss 0.0846	
start validation
acc: 0.870053
AUC: 0.994638
Avg Precision: 0.172907
Avg Recall: 0.993313
d_prime: 3.608478
train_loss: 0.084520
valid_loss: 0.027097
validation finished
Epoch-16 lr: 3.556043928404302e-05
epoch 16 training time: 445.340
--



Epoch: [17][8/662]	Per Sample Total Time 0.00947	Per Sample Data Time 0.00451	Per Sample DNN Time 0.00496	Train Loss 0.0841	
Epoch: [17][108/662]	Per Sample Total Time 0.00524	Per Sample Data Time 0.00037	Per Sample DNN Time 0.00486	Train Loss 0.0846	
Epoch: [17][208/662]	Per Sample Total Time 0.00505	Per Sample Data Time 0.00020	Per Sample DNN Time 0.00486	Train Loss 0.0844	
Epoch: [17][308/662]	Per Sample Total Time 0.00499	Per Sample Data Time 0.00013	Per Sample DNN Time 0.00486	Train Loss 0.0843	
Epoch: [17][408/662]	Per Sample Total Time 0.00496	Per Sample Data Time 0.00010	Per Sample DNN Time 0.00486	Train Loss 0.0843	
Epoch: [17][508/662]	Per Sample Total Time 0.00494	Per Sample Data Time 0.00008	Per Sample DNN Time 0.00486	Train Loss 0.0841	
Epoch: [17][608/662]	Per Sample Total Time 0.00493	Per Sample Data Time 0.00007	Per Sample DNN Time 0.00486	Train Loss 0.0840	
start validation
acc: 0.874461
AUC: 0.994886
Avg Precision: 0.178615
Avg Recall: 0.993522
d_prime: 3.631712
train



Epoch: [18][46/662]	Per Sample Total Time 0.00572	Per Sample Data Time 0.00085	Per Sample DNN Time 0.00487	Train Loss 0.0842	
Epoch: [18][146/662]	Per Sample Total Time 0.00513	Per Sample Data Time 0.00027	Per Sample DNN Time 0.00486	Train Loss 0.0839	
Epoch: [18][246/662]	Per Sample Total Time 0.00502	Per Sample Data Time 0.00016	Per Sample DNN Time 0.00486	Train Loss 0.0839	
Epoch: [18][346/662]	Per Sample Total Time 0.00497	Per Sample Data Time 0.00012	Per Sample DNN Time 0.00485	Train Loss 0.0838	
Epoch: [18][446/662]	Per Sample Total Time 0.00495	Per Sample Data Time 0.00009	Per Sample DNN Time 0.00486	Train Loss 0.0837	
Epoch: [18][546/662]	Per Sample Total Time 0.00493	Per Sample Data Time 0.00008	Per Sample DNN Time 0.00486	Train Loss 0.0836	
Epoch: [18][646/662]	Per Sample Total Time 0.00492	Per Sample Data Time 0.00006	Per Sample DNN Time 0.00485	Train Loss 0.0835	
start validation
acc: 0.877467
AUC: 0.995062
Avg Precision: 0.177008
Avg Recall: 0.993388
d_prime: 3.648861
trai



Epoch: [19][84/662]	Per Sample Total Time 0.00535	Per Sample Data Time 0.00048	Per Sample DNN Time 0.00486	Train Loss 0.0830	
Epoch: [19][184/662]	Per Sample Total Time 0.00508	Per Sample Data Time 0.00022	Per Sample DNN Time 0.00486	Train Loss 0.0832	
Epoch: [19][284/662]	Per Sample Total Time 0.00500	Per Sample Data Time 0.00015	Per Sample DNN Time 0.00486	Train Loss 0.0832	
Epoch: [19][384/662]	Per Sample Total Time 0.00496	Per Sample Data Time 0.00011	Per Sample DNN Time 0.00486	Train Loss 0.0833	
Epoch: [19][484/662]	Per Sample Total Time 0.00494	Per Sample Data Time 0.00009	Per Sample DNN Time 0.00485	Train Loss 0.0832	
Epoch: [19][584/662]	Per Sample Total Time 0.00493	Per Sample Data Time 0.00007	Per Sample DNN Time 0.00486	Train Loss 0.0833	
start validation
acc: 0.879471
AUC: 0.995279
Avg Precision: 0.189474
Avg Recall: 0.993201
d_prime: 3.670778
train_loss: 0.083203
valid_loss: 0.025681
validation finished
Epoch-19 lr: 2.1838554775312915e-05
epoch 19 training time: 446.218
-



Epoch: [20][22/662]	Per Sample Total Time 0.00675	Per Sample Data Time 0.00185	Per Sample DNN Time 0.00490	Train Loss 0.0831	
Epoch: [20][122/662]	Per Sample Total Time 0.00521	Per Sample Data Time 0.00035	Per Sample DNN Time 0.00486	Train Loss 0.0823	
Epoch: [20][222/662]	Per Sample Total Time 0.00505	Per Sample Data Time 0.00019	Per Sample DNN Time 0.00486	Train Loss 0.0826	
Epoch: [20][322/662]	Per Sample Total Time 0.00499	Per Sample Data Time 0.00013	Per Sample DNN Time 0.00486	Train Loss 0.0825	
Epoch: [20][422/662]	Per Sample Total Time 0.00496	Per Sample Data Time 0.00010	Per Sample DNN Time 0.00485	Train Loss 0.0826	
Epoch: [20][522/662]	Per Sample Total Time 0.00494	Per Sample Data Time 0.00008	Per Sample DNN Time 0.00485	Train Loss 0.0827	
Epoch: [20][622/662]	Per Sample Total Time 0.00493	Per Sample Data Time 0.00007	Per Sample DNN Time 0.00485	Train Loss 0.0827	
start validation
acc: 0.879271
AUC: 0.995434
Avg Precision: 0.178181
Avg Recall: 0.994049
d_prime: 3.686999
trai

In [None]:
# # Model checker
# ast_mdl.train()

# for i, (audio_input, labels) in enumerate(train_loader):

#     print(str(i)+'th instance')

#     B = audio_input.size(0)
#     audio_input = audio_input.to(device, non_blocking=True)
#     labels = labels.to(device, non_blocking=True)

#     with autocast():
#         audio_output = ast_mdl(audio_input)
#     print("Output shape", audio_output.shape)
#     print("Done")
#     break

In [None]:
# Display the path of the best checkpoint under the experiment directory.
args.exp_dir + '/models/best_audio_model.pth'

'/content/drive/MyDrive/Thesis/resout_qast/models/best_audio_model.pth'

In [None]:
# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Load the saved state dict, remapping its tensors onto `device`.
sd = torch.load(args.exp_dir + '/models/best_audio_model.pth', map_location=device)
# Wrap the model in DataParallel before loading the weights — presumably because the
# checkpoint was saved from a DataParallel-wrapped model; TODO confirm in the training code.
audio_model = torch.nn.DataParallel(ast_mdl)
audio_model.load_state_dict(sd)

<All keys matched successfully>

In [None]:
# Display the loss name stored on `args`.
args.loss

'BCE'

In [None]:
# Attach the loss module to `args` — presumably `validate` reads `args.loss_fn`;
# TODO confirm in src.traintest.
loss_fn = nn.BCEWithLogitsLoss()
args.loss_fn = loss_fn

In [None]:
import numpy as np
# Evaluate the loaded best model on the validation set.
stats, _ = validate(audio_model, val_loader, args, 'valid_set')
# Accuracy is read from the first stats entry only; note it is NOT the mean of class-wise accuracy.
val_acc = stats[0]['acc']
# AUC, in contrast, is averaged over every entry of `stats`
# (presumably one entry per class — TODO confirm in src.traintest).
val_mAUC = np.mean([stat['auc'] for stat in stats])
print('---------------evaluate on the validation set---------------')
print("Accuracy: {:.6f}".format(val_acc))
print("AUC: {:.6f}".format(val_mAUC))

---------------evaluate on the validation set---------------
Accuracy: 0.879371
AUC: 0.995279


In [None]:
# Display the first entry of the most recently computed `stats` list.
stats[0]

{'precisions': array([0.01499318, 0.07860886]),
 'recalls': array([1., 1.]),
 'AP': 0.9278634102889988,
 'fpr': array([0.        , 0.61411439]),
 'fnr': array([1., 0.]),
 'auc': 0.9971687353237169,
 'acc': 0.8675147660154475}

In [None]:
# Test the model on the evaluation set; this rebinds `stats` from the validation cell.
stats, _ = validate(audio_model, eval_loader, args, 'eval_set')
# Same metric definitions as for the validation set: accuracy from the first
# stats entry, AUC averaged over all entries.
eval_acc = stats[0]['acc']
eval_mAUC = np.mean([stat['auc'] for stat in stats])
print('---------------evaluate on the test set---------------')
print("Accuracy: {:.6f}".format(eval_acc))
print("AUC: {:.6f}".format(eval_mAUC))
# Save validation and test metrics together. `val_acc` and `val_mAUC` come from the
# validation cell, so that cell must have been run first.
np.savetxt(args.exp_dir + '/eval_result.csv', [val_acc, val_mAUC, eval_acc, eval_mAUC])

---------------evaluate on the test set---------------
Accuracy: 0.867515
AUC: 0.994489


# Pretraining QAST on Images

In [20]:
"""
Preprocessing:
img and label
for label int to one hot convert is below
"""

from torch.utils.data import DataLoader
from torchvision import transforms

def transforms_(examples):
    """Convert a batch of images to RGB and resize them, storing the result under "image".

    Images found under the 'img' key are resized to 128x128. Otherwise images
    found under the 'image' key are resized to 64x64. The batch is modified in place.

    Args:
        examples: Mapping from column name to a list of values. The image objects
            must provide ``convert`` and ``resize`` (presumably PIL images —
            TODO confirm against the dataset).

    Returns:
        The same ``examples`` object, or ``None`` (after printing a hint) when
        neither image key is present.
    """
    # (source column, side length of the square output), checked in priority order.
    for source_key, side in (('img', 128), ('image', 64)):
        if source_key in examples:
            examples["image"] = [
                picture.convert("RGB").resize((side, side))
                for picture in examples[source_key]
            ]
            return examples

    print("please check the dataset keys")
    

def collate_fn(examples):
    """Collate dataset examples into an image batch and one-hot label batch.

    Args:
        examples: Sequence of mappings, each with an "image" entry (converted with
            ``transforms.ToTensor()``) and an integer "label" entry.

    Returns:
        Tuple ``(pixel_values, y)``: the stacked image tensors and a
        ``(batch, n_classes)`` tensor holding 1 at each example's label index
        and 0 elsewhere.

    Note:
        The number of classes is read from the notebook-level
        ``args_pretraining.n_classes``, which must be set before the first batch
        is collated.
    """
    to_tensor = transforms.ToTensor()

    pixel_values = torch.stack([to_tensor(sample["image"]) for sample in examples])
    labels = torch.tensor([sample["label"] for sample in examples])

    # One row per example; set the column of each example's class to 1.
    one_hot = torch.zeros(labels.shape[0], args_pretraining.n_classes)
    one_hot[range(one_hot.shape[0]), labels] = 1
    return (pixel_values, one_hot)

In [29]:
import torch
import torchvision
import timm
from enum import Enum
from typing import Union

class Format(str, Enum):
    """Names of the tensor dimension layouts handled by the layout helpers.

    Members are also plain strings, so ``Format.NCHW == 'NCHW'`` holds.
    """
    # 4-D layouts: channel axis right after the batch axis, or last.
    NCHW, NHWC = 'NCHW', 'NHWC'
    # 3-D layouts with a single flattened spatial axis L.
    NCL, NLC = 'NCL', 'NLC'


# A layout may be given either as a Format member or as its string value.
FormatT = Union[str, Format]


def get_spatial_dim(fmt: FormatT):
    """Return the indices of the spatial axes for a layout.

    Args:
        fmt: A ``Format`` member or its string value.

    Returns:
        Tuple of axis indices: ``(1,)`` for NLC, ``(2,)`` for NCL, ``(1, 2)``
        for NHWC and ``(2, 3)`` for the remaining layout (NCHW).

    Raises:
        ValueError: If ``fmt`` is not a valid ``Format`` value.
    """
    # NCHW is deliberately absent: it is the fallback of the lookup.
    spatial_axes = {
        Format.NLC: (1,),
        Format.NCL: (2,),
        Format.NHWC: (1, 2),
    }
    return spatial_axes.get(Format(fmt), (2, 3))


def get_channel_dim(fmt: FormatT):
    """Return the index of the channel axis for a layout.

    Args:
        fmt: A ``Format`` member or its string value.

    Returns:
        ``3`` for NHWC, ``2`` for NLC and ``1`` for the remaining layouts
        (NCHW and NCL).

    Raises:
        ValueError: If ``fmt`` is not a valid ``Format`` value.
    """
    layout = Format(fmt)
    if layout is Format.NHWC:
        return 3
    if layout is Format.NLC:
        return 2
    return 1


def nchw_to(x: torch.Tensor, fmt: Format):
    """Rearrange a tensor laid out as NCHW into the layout named by ``fmt``.

    Args:
        x: Input tensor, treated as NCHW.
        fmt: Target layout.

    Returns:
        The rearranged tensor: axes permuted to ``(0, 2, 3, 1)`` for NHWC;
        flattened from axis 2 and then axes 1 and 2 swapped for NLC;
        flattened from axis 2 for NCL. Any other ``fmt`` returns ``x`` unchanged.
    """
    if fmt == Format.NHWC:
        return x.permute(0, 2, 3, 1)
    if fmt == Format.NLC:
        return x.flatten(2).transpose(1, 2)
    if fmt == Format.NCL:
        return x.flatten(2)
    return x


def nhwc_to(x: torch.Tensor, fmt: Format):
    """Rearrange a tensor laid out as NHWC into the layout named by ``fmt``.

    Args:
        x: Input tensor, treated as NHWC.
        fmt: Target layout.

    Returns:
        The rearranged tensor: axes permuted to ``(0, 3, 1, 2)`` for NCHW;
        axes 1 and 2 flattened for NLC; axes 1 and 2 flattened and the last two
        axes swapped for NCL. Any other ``fmt`` returns ``x`` unchanged.
    """
    if fmt == Format.NCHW:
        x = x.permute(0, 3, 1, 2)
    elif fmt == Format.NLC:
        x = x.flatten(1, 2)
    elif fmt == Format.NCL:
        x = x.flatten(1, 2).transpose(1, 2)
    # Return the converted tensor; a bare `return` here would hand callers None.
    return x

import logging
from typing import List, Optional, Callable

import torch
from torch import nn as nn
from torch import _assert
import torch.nn.functional as F

_logger = logging.getLogger(__name__)


class PatchEmbedding(nn.Module):
    """ 2D Image to Patch Embedding

    ``forward`` prepends one all-zero channel to the input before the patch
    projection, so the projection receives ``C + 1`` channels. ``in_chans`` must
    therefore be one more than the number of channels of the images passed in
    (e.g. the default ``in_chans=4`` for 3-channel images).

    Args:
        img_size: Expected input size (int or pair). ``None`` disables the size
            check and leaves ``grid_size`` / ``num_patches`` unset.
        patch_size: Patch size (int or pair), used as kernel size and stride of
            the projection.
        in_chans: Input channels of the projection, counting the zero channel.
        embed_dim: Output channels of the projection.
        norm_layer: Optional callable building a norm layer from ``embed_dim``;
            ``nn.Identity`` is used when omitted.
        flatten: When ``output_fmt`` is not given, whether to flatten the output
            to NLC.
        output_fmt: Optional output layout; when given it overrides ``flatten``.
        bias: Whether the projection has a bias term.

    NOTE(review): ``to_2tuple`` is not imported in this cell; presumably it is
    brought into scope earlier in the notebook — confirm.
    """
    output_fmt: Format

    def __init__(
            self,
            img_size: Optional[int] = 32,
            patch_size: int = 16,
            in_chans: int = 4,
            embed_dim: int = 768,
            norm_layer: Optional[Callable] = None,
            flatten: bool = True,
            output_fmt: Optional[str] = None,
            bias: bool = True,
    ):
        super().__init__()
        self.patch_size = to_2tuple(patch_size)
        if img_size is not None:
            self.img_size = to_2tuple(img_size)
            # Number of whole patches along each axis (integer division drops any remainder).
            self.grid_size = tuple([s // p for s, p in zip(self.img_size, self.patch_size)])
            self.num_patches = self.grid_size[0] * self.grid_size[1]
        else:
            self.img_size = None
            self.grid_size = None
            self.num_patches = None

        if output_fmt is not None:
            # An explicit output format takes precedence over `flatten`.
            self.flatten = False
            self.output_fmt = Format(output_fmt)
        else:
            # flatten spatial dim and transpose to channels last, kept for bwd compat
            self.flatten = flatten
            self.output_fmt = Format.NCHW

        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias)
        self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()

    def forward(self, x):
        """Embed a batch of NCHW images as patches.

        Args:
            x: Tensor of shape ``(B, C, H, W)``; ``H``/``W`` must match
                ``img_size`` when it is set.

        Returns:
            The projected (and normalised) patches: NLC when ``flatten`` is
            true, otherwise in ``output_fmt``.
        """
        B, C, H, W = x.shape
        if self.img_size is not None:
            _assert(H == self.img_size[0], f"Input image height ({H}) doesn't match model ({self.img_size[0]}).")
            _assert(W == self.img_size[1], f"Input image width ({W}) doesn't match model ({self.img_size[1]}).")

        # Allocate the extra zero channel directly on the input's device. This
        # removes the dependency on a notebook-level `device` variable and keeps
        # working when `x` lives on another device (e.g. DataParallel replicas).
        # The dtype stays at the default, so torch.cat's type promotion is unchanged.
        zeros = torch.zeros((B, 1, H, W), device=x.device)
        x = torch.cat((zeros, x), 1)
        x = self.proj(x)
        if self.flatten:
            x = x.flatten(2).transpose(1, 2)  # NCHW -> NLC
        elif self.output_fmt != Format.NCHW:
            x = nchw_to(x, self.output_fmt)
        x = self.norm(x)
        return x



In [30]:
# create a folder for results: predictions and models
# create target.csv

import os
import csv
import shutil

def create_target_csv(path, val_data):
  """Write every target row of ``val_data`` to ``<path>/target.csv``.

  Args:
    path: Existing directory in which ``target.csv`` is created (overwritten
      if already present).
    val_data: Iterable of batches. Element ``[1]`` of each batch must provide
      ``tolist()`` returning a list of rows (presumably the one-hot label
      tensor produced by the collate function — TODO confirm).
  """
  out_target = []
  for batch in val_data:
    out_target.extend(batch[1].tolist())

  # The csv module writes its own line terminators, so the file must be opened
  # with newline='' to avoid doubled line endings on platforms that translate '\n'.
  with open(os.path.join(path, "target.csv"), 'w', newline='') as f:
    csv.writer(f).writerows(out_target)

def prepare_result_saving(args, val_data = None):
  """Make sure the experiment folders and ``predictions/target.csv`` exist.

  Creates ``args.exp_dir`` with its ``models`` and ``predictions`` subfolders.
  Missing subfolders are created even when ``args.exp_dir`` itself already
  exists. If ``predictions/target.csv`` is missing, it is generated from
  ``val_data`` when given, otherwise copied from the hard-coded ``resout_qast``
  experiment folder.

  Args:
    args: Object with an ``exp_dir`` attribute (path of the experiment folder).
    val_data: Optional iterable of batches forwarded to ``create_target_csv``.

  Raises:
    FileNotFoundError: If ``val_data`` is None, ``target.csv`` is missing and
      the fallback file to copy does not exist.
  """
  models_dir = os.path.join(args.exp_dir, "models")
  predictions_dir = os.path.join(args.exp_dir, "predictions")

  exp_dir_existed = os.path.exists(args.exp_dir)
  # exist_ok=True also repairs an experiment folder that exists but lacks one
  # of its subfolders, which would otherwise fail later when files are saved.
  os.makedirs(models_dir, exist_ok=True)
  os.makedirs(predictions_dir, exist_ok=True)
  if exp_dir_existed:
    print("Folders already exists")
  else:
    print("Created folders")

  if os.path.exists(os.path.join(predictions_dir, "target.csv")):
    print("Target.csv already exists")
  elif val_data is not None:
    create_target_csv(predictions_dir, val_data)
  else:
    shutil.copy('/content/drive/MyDrive/Thesis/resout_qast/predictions/target.csv', predictions_dir)
    print("copied target")


In [31]:
class QVIT(nn.Module):
  """Transformer image classifier built from ``PatchEmbedding`` and ``Block`` layers.

  The input is split into 16x16 patches, a class token and a distillation token
  are prepended, learned positional embeddings are added, and the sequence runs
  through 12 ``Block`` layers. The outputs of the two special tokens are averaged
  and fed to a linear classification head.

  Args:
    label_dim: Number of output classes; values <= 0 replace the head by ``nn.Identity``.
    img_size: Input image size forwarded to ``PatchEmbedding``.
    verbose: Accepted for interface compatibility; not used.
  """

  def __init__(self, label_dim=200, img_size=32, verbose=True):
    super().__init__()
    self.original_embedding_dim = 768
    embed_dim = self.original_embedding_dim

    # Encoder hyper-parameters.
    depth = 12
    num_heads = 12
    block_kwargs = dict(
        mlp_ratio=4.,
        qkv_bias=True,
        qk_norm=False,
        drop=0.1,
        attn_drop=0.,
    )

    # in_chans=4 matches the 4-channel input PatchEmbedding builds internally.
    self.patch_embed = PatchEmbedding(
        img_size=img_size,
        patch_size=16,
        in_chans=4,
        embed_dim=embed_dim,
    )
    # Two extra positions: class token and distillation token.
    sequence_length = self.patch_embed.num_patches + 2

    self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
    self.dist_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
    # TODO pretrained or sinusoidal
    self.pos_embed = nn.Parameter(torch.zeros(1, sequence_length, embed_dim))

    trunc_normal_(self.pos_embed, std=.02)
    self.pos_drop = nn.Dropout(p=0.)

    encoder_layers = [Block(embed_dim, num_heads, **block_kwargs) for _ in range(depth)]
    self.blocks = nn.Sequential(*encoder_layers)
    self.norm = Norm(embed_dim)

    # Classifier Head
    self.fc_norm = Norm(embed_dim)
    if label_dim > 0:
      self.head = nn.Linear(embed_dim, label_dim)
    else:
      self.head = nn.Identity()

  @autocast()
  def forward(self, x):
    """Classify a batch of images.

    Args:
      x: Image batch accepted by ``PatchEmbedding`` (first axis is the batch).

    Returns:
      The head output for the averaged class/distillation token representation.
    """
    batch_size = x.shape[0]
    patches = self.patch_embed(x)

    # Broadcast the two learned tokens over the batch and put them in front.
    special_tokens = (
        self.cls_token.expand(batch_size, -1, -1),
        self.dist_token.expand(batch_size, -1, -1),
    )
    tokens = torch.cat((*special_tokens, patches), dim=1)
    tokens = self.pos_drop(tokens + self.pos_embed)

    tokens = self.norm(self.blocks(tokens))

    # Average the class-token and distillation-token outputs.
    pooled = (tokens[:, 0] + tokens[:, 1]) / 2
    return self.head(self.fc_norm(pooled))
     


## CIFAR-10

In [38]:
from datasets import load_dataset, Image
# Open the CIFAR-10 train and test splits in streaming mode and cast the "image" column to Image.
# NOTE(review): only the test split passes `use_auth_token=True`; confirm whether
# the train split needs it too, or whether it can be dropped from both.
dset = load_dataset('cifar10', split='train', streaming=True).cast_column("image", Image())
dset_test = load_dataset('cifar10', split='test', streaming=True, use_auth_token=True).cast_column("image", Image())




In [39]:
# CIFAR10 on Hugging Face doesn't have an eval set, so only train and test splits are prepared.

# Apply the RGB conversion / resize preprocessing to both splits, batch-wise.
# Both names are rebound, so re-running this cell stacks another map on top.
dset = dset.map(transforms_, batched=True)
dset_test = dset_test.map(transforms_, batched=True)

# Request the "torch" format for the datasets handed to the DataLoaders.
dset_iter = dset.with_format("torch")
dset_test_iter = dset_test.with_format("torch")

In [34]:
# class Arguments():
#   bal=None
#   lr=2.5e-4
#   mixup=0.6
#   noise=True
#   num_workers = 32
#   optimizer = 'adam'
#   metrics='acc'
#   loss='BCE'              
#   lrscheduler_start=5
#   lrscheduler_step=1
#   lrscheduler_decay=0.85
#   warmup = False
#   wa = False
#   wa_start = 1
#   wa_end = 5
#   lr_patience = 2
#   save_model = True

# Start from the shared `Arguments` defaults, then override what differs for CIFAR-10 pretraining.
args_pretraining = Arguments()
args_pretraining.exp_dir = '/content/drive/MyDrive/Thesis/pretrain_qvit_CIFAR10'
# Read by collate_fn to size the one-hot label vectors.
args_pretraining.n_classes = 10
# Passed to QVIT as img_size; matches the 128x128 resize transforms_ applies to the 'img' column.
args_pretraining.img_size = 128

args_pretraining.batch_size = 128
args_pretraining.n_epochs = 30
args_pretraining.lr = 2.5e-4
args_pretraining.warmup = False

# args_pretraining.n_print_steps = 100
# args_pretraining.dataset_mean = 0
# args_pretraining.dataset_std = 0
# args_pretraining.num_workers = 8
args_pretraining.mixup = 0.6
# args_pretraining.wa = True
# args_pretraining.lrscheduler_start=0
# args_pretraining.lrscheduler_step=1
# args_pretraining.lrscheduler_decay=0.85

In [42]:
# Instantiate QVIT with the label count and image size configured for pretraining.
ast_mdl_not_pretrained = QVIT(label_dim=args_pretraining.n_classes, img_size = args_pretraining.img_size)

In [43]:
# Batch both splits; collate_fn turns each batch into (stacked images, one-hot labels).
train_loader_cifar10 = DataLoader(dset_iter, collate_fn=collate_fn, batch_size = args_pretraining.batch_size, pin_memory=True)
test_loader_cifar10 = DataLoader(dset_test_iter, collate_fn=collate_fn, batch_size=args_pretraining.batch_size, pin_memory=True)


In [44]:
# Make sure the experiment folders exist and that predictions/target.csv is present;
# the test loader is only iterated when target.csv still has to be written.
prepare_result_saving(args_pretraining, val_data = test_loader_cifar10)

Folders already exists
Target.csv already exists


In [45]:
# Move the model to `device`; the returned module is what the cell displays.
ast_mdl_not_pretrained.to(device)

QVIT(
  (patch_embed): PatchEmbedding(
    (proj): Conv2d(4, 768, kernel_size=(16, 16), stride=(16, 16))
    (norm): Identity()
  )
  (pos_drop): Dropout(p=0.0, inplace=False)
  (blocks): Sequential(
    (0): Block(
      (norm1): Norm()
      (attn): Attention(
        (qkv): QuaternionLinearAutograd(in_features=192, out_features=576, bias=True, init_criterion=glorot, weight_init=quaternion, rotation=False, seed=69)
        (q_norm): Identity()
        (k_norm): Identity()
        (attn_drop): Dropout(p=0.0, inplace=False)
        (proj): QuaternionLinearAutograd(in_features=192, out_features=192, bias=True, init_criterion=glorot, weight_init=quaternion, rotation=False, seed=388)
        (proj_drop): Dropout(p=0.1, inplace=False)
      )
      (norm2): Norm()
      (mlp): Mlp(
        (fc1): QuaternionLinearAutograd(in_features=192, out_features=768, bias=True, init_criterion=glorot, weight_init=quaternion, rotation=False, seed=548)
        (drop1): Dropout(p=0.1, inplace=False)
     

In [46]:
# args_pretraining.lr = 2.5e-4
print('Now starting training for {:d} epochs'.format(args_pretraining.n_epochs))
# Use the train/validate variants from src.traintest_cust under distinct names.
from src.traintest_cust import train as train_cust
from src.traintest_cust import validate as validate_cust 
# The CIFAR-10 test loader is passed third — presumably as the validation loader;
# TODO confirm against train's signature in src.traintest_cust.
train_cust(ast_mdl_not_pretrained, train_loader_cifar10, test_loader_cifar10, args_pretraining)

Now starting training for 30 epochs
running on cuda
Total parameter number is : 22.174 million
Total trainable parameter number is : 22.174 million
now training with speechcommands, main metrics: acc, loss function: BCEWithLogitsLoss(), learning rate scheduler: <torch.optim.lr_scheduler.MultiStepLR object at 0x7fadf153e380>
The learning rate scheduler starts at 5 epoch with decay rate of 0.850 every 1 epochs
current #steps=0, #epochs=1
start training...
---------------
2023-06-07 16:28:11.711850
current #epochs=1, #steps=0




Epoch: [1][100/unk]	Per Sample Total Time 0.00458	Per Sample Data Time 0.00144	Per Sample DNN Time 0.00314	Train Loss 0.3309	
Epoch: [1][200/unk]	Per Sample Total Time 0.00449	Per Sample Data Time 0.00135	Per Sample DNN Time 0.00314	Train Loss 0.3237	
Epoch: [1][300/unk]	Per Sample Total Time 0.00443	Per Sample Data Time 0.00129	Per Sample DNN Time 0.00314	Train Loss 0.3169	
start validation
acc: 0.250300
AUC: 0.728984
Avg Precision: 0.189392
Avg Recall: 0.588300
d_prime: 0.862307
train_loss: 0.312160
valid_loss: 0.730730
validation finished
Epoch-1 lr: 0.00025
epoch 1 training time: 253.772
---------------
2023-06-07 16:32:25.484586
current #epochs=2, #steps=391




Epoch: [2][9/unk]	Per Sample Total Time 0.00626	Per Sample Data Time 0.00310	Per Sample DNN Time 0.00316	Train Loss 0.2963	
Epoch: [2][109/unk]	Per Sample Total Time 0.00459	Per Sample Data Time 0.00145	Per Sample DNN Time 0.00314	Train Loss 0.2912	
Epoch: [2][209/unk]	Per Sample Total Time 0.00442	Per Sample Data Time 0.00129	Per Sample DNN Time 0.00314	Train Loss 0.2907	
Epoch: [2][309/unk]	Per Sample Total Time 0.00442	Per Sample Data Time 0.00128	Per Sample DNN Time 0.00314	Train Loss 0.2886	
start validation
acc: 0.312300
AUC: 0.780988
Avg Precision: 0.190658
Avg Recall: 0.798900
d_prime: 1.096770
train_loss: 0.286902
valid_loss: 0.726345
validation finished
Epoch-2 lr: 0.00025
epoch 2 training time: 252.541
---------------
2023-06-07 16:36:38.025529
current #epochs=3, #steps=782




Epoch: [3][18/unk]	Per Sample Total Time 0.00530	Per Sample Data Time 0.00216	Per Sample DNN Time 0.00314	Train Loss 0.2797	
Epoch: [3][118/unk]	Per Sample Total Time 0.00455	Per Sample Data Time 0.00141	Per Sample DNN Time 0.00314	Train Loss 0.2768	
Epoch: [3][218/unk]	Per Sample Total Time 0.00443	Per Sample Data Time 0.00130	Per Sample DNN Time 0.00314	Train Loss 0.2765	
Epoch: [3][318/unk]	Per Sample Total Time 0.00445	Per Sample Data Time 0.00131	Per Sample DNN Time 0.00314	Train Loss 0.2743	
start validation
acc: 0.340800
AUC: 0.815974
Avg Precision: 0.193000
Avg Recall: 0.883700
d_prime: 1.272975
train_loss: 0.273018
valid_loss: 0.723418
validation finished
Epoch-3 lr: 0.00025
epoch 3 training time: 253.162
---------------
2023-06-07 16:40:51.187344
current #epochs=4, #steps=1173




Epoch: [4][27/unk]	Per Sample Total Time 0.00512	Per Sample Data Time 0.00192	Per Sample DNN Time 0.00320	Train Loss 0.2671	
Epoch: [4][127/unk]	Per Sample Total Time 0.00457	Per Sample Data Time 0.00142	Per Sample DNN Time 0.00315	Train Loss 0.2611	
Epoch: [4][227/unk]	Per Sample Total Time 0.00445	Per Sample Data Time 0.00131	Per Sample DNN Time 0.00314	Train Loss 0.2596	
Epoch: [4][327/unk]	Per Sample Total Time 0.00444	Per Sample Data Time 0.00130	Per Sample DNN Time 0.00314	Train Loss 0.2567	
start validation
acc: 0.382900
AUC: 0.846799
Avg Precision: 0.251329
Avg Recall: 0.845100
d_prime: 1.446461
train_loss: 0.255825
valid_loss: 0.718725
validation finished
Epoch-4 lr: 0.00025
epoch 4 training time: 253.034
---------------
2023-06-07 16:45:04.222408
current #epochs=5, #steps=1564




Epoch: [5][36/unk]	Per Sample Total Time 0.00463	Per Sample Data Time 0.00149	Per Sample DNN Time 0.00314	Train Loss 0.2470	
Epoch: [5][136/unk]	Per Sample Total Time 0.00447	Per Sample Data Time 0.00133	Per Sample DNN Time 0.00314	Train Loss 0.2469	
Epoch: [5][236/unk]	Per Sample Total Time 0.00447	Per Sample Data Time 0.00133	Per Sample DNN Time 0.00314	Train Loss 0.2460	
Epoch: [5][336/unk]	Per Sample Total Time 0.00442	Per Sample Data Time 0.00128	Per Sample DNN Time 0.00314	Train Loss 0.2436	
start validation
acc: 0.437100
AUC: 0.866172
Avg Precision: 0.264803
Avg Recall: 0.847100
d_prime: 1.567621
train_loss: 0.242411
valid_loss: 0.716060
validation finished
Epoch-5 lr: 0.0002125
epoch 5 training time: 251.276
---------------
2023-06-07 16:49:15.498457
current #epochs=6, #steps=1955




Epoch: [6][45/unk]	Per Sample Total Time 0.00451	Per Sample Data Time 0.00137	Per Sample DNN Time 0.00313	Train Loss 0.2328	
Epoch: [6][145/unk]	Per Sample Total Time 0.00444	Per Sample Data Time 0.00130	Per Sample DNN Time 0.00314	Train Loss 0.2313	
Epoch: [6][245/unk]	Per Sample Total Time 0.00448	Per Sample Data Time 0.00135	Per Sample DNN Time 0.00314	Train Loss 0.2296	
Epoch: [6][345/unk]	Per Sample Total Time 0.00446	Per Sample Data Time 0.00133	Per Sample DNN Time 0.00314	Train Loss 0.2278	
start validation
acc: 0.482300
AUC: 0.884143
Avg Precision: 0.366739
Avg Recall: 0.724000
d_prime: 1.691340
train_loss: 0.226907
valid_loss: 0.711885
validation finished
Epoch-6 lr: 0.00018062499999999999
epoch 6 training time: 254.670
---------------
2023-06-07 16:53:30.168968
current #epochs=7, #steps=2346




Epoch: [7][54/unk]	Per Sample Total Time 0.00455	Per Sample Data Time 0.00142	Per Sample DNN Time 0.00313	Train Loss 0.2174	
Epoch: [7][154/unk]	Per Sample Total Time 0.00443	Per Sample Data Time 0.00129	Per Sample DNN Time 0.00314	Train Loss 0.2160	
Epoch: [7][254/unk]	Per Sample Total Time 0.00446	Per Sample Data Time 0.00132	Per Sample DNN Time 0.00314	Train Loss 0.2151	
Epoch: [7][354/unk]	Per Sample Total Time 0.00441	Per Sample Data Time 0.00128	Per Sample DNN Time 0.00314	Train Loss 0.2136	
start validation
acc: 0.506100
AUC: 0.895101
Avg Precision: 0.346144
Avg Recall: 0.791000
d_prime: 1.773597
train_loss: 0.212887
valid_loss: 0.710876
validation finished
Epoch-7 lr: 0.00015353125
epoch 7 training time: 252.929
---------------
2023-06-07 16:57:43.098568
current #epochs=8, #steps=2737




Epoch: [8][63/unk]	Per Sample Total Time 0.00448	Per Sample Data Time 0.00135	Per Sample DNN Time 0.00313	Train Loss 0.2077	
Epoch: [8][163/unk]	Per Sample Total Time 0.00447	Per Sample Data Time 0.00134	Per Sample DNN Time 0.00313	Train Loss 0.2035	
Epoch: [8][263/unk]	Per Sample Total Time 0.00444	Per Sample Data Time 0.00131	Per Sample DNN Time 0.00313	Train Loss 0.2046	
Epoch: [8][363/unk]	Per Sample Total Time 0.00443	Per Sample Data Time 0.00130	Per Sample DNN Time 0.00313	Train Loss 0.2026	
start validation
acc: 0.538500
AUC: 0.905268
Avg Precision: 0.393661
Avg Recall: 0.790100
d_prime: 1.855680
train_loss: 0.202098
valid_loss: 0.707527
validation finished
Epoch-8 lr: 0.0001305015625
epoch 8 training time: 258.104
---------------
2023-06-07 17:02:01.203141
current #epochs=9, #steps=3128




Epoch: [9][72/unk]	Per Sample Total Time 0.00479	Per Sample Data Time 0.00165	Per Sample DNN Time 0.00313	Train Loss 0.1968	
Epoch: [9][172/unk]	Per Sample Total Time 0.00474	Per Sample Data Time 0.00161	Per Sample DNN Time 0.00313	Train Loss 0.1936	
Epoch: [9][272/unk]	Per Sample Total Time 0.00457	Per Sample Data Time 0.00143	Per Sample DNN Time 0.00313	Train Loss 0.1943	
Epoch: [9][372/unk]	Per Sample Total Time 0.00449	Per Sample Data Time 0.00136	Per Sample DNN Time 0.00313	Train Loss 0.1925	
start validation
acc: 0.553000
AUC: 0.911129
Avg Precision: 0.403010
Avg Recall: 0.800700
d_prime: 1.905995
train_loss: 0.192287
valid_loss: 0.705328
validation finished
Epoch-9 lr: 0.00011092632812499999
epoch 9 training time: 257.577
---------------
2023-06-07 17:06:18.780108
current #epochs=10, #steps=3519




Epoch: [10][81/unk]	Per Sample Total Time 0.00476	Per Sample Data Time 0.00163	Per Sample DNN Time 0.00313	Train Loss 0.1899	
Epoch: [10][181/unk]	Per Sample Total Time 0.00449	Per Sample Data Time 0.00136	Per Sample DNN Time 0.00313	Train Loss 0.1862	
Epoch: [10][281/unk]	Per Sample Total Time 0.00442	Per Sample Data Time 0.00129	Per Sample DNN Time 0.00313	Train Loss 0.1857	
Epoch: [10][381/unk]	Per Sample Total Time 0.00439	Per Sample Data Time 0.00126	Per Sample DNN Time 0.00313	Train Loss 0.1838	
start validation
acc: 0.566500
AUC: 0.916452
Avg Precision: 0.402657
Avg Recall: 0.817700
d_prime: 1.953871
train_loss: 0.183753
valid_loss: 0.705096
validation finished
Epoch-10 lr: 9.428737890624999e-05
epoch 10 training time: 254.011
---------------
2023-06-07 17:10:32.790915
current #epochs=11, #steps=3910




Epoch: [11][90/unk]	Per Sample Total Time 0.00472	Per Sample Data Time 0.00158	Per Sample DNN Time 0.00314	Train Loss 0.1808	
Epoch: [11][190/unk]	Per Sample Total Time 0.00449	Per Sample Data Time 0.00136	Per Sample DNN Time 0.00313	Train Loss 0.1789	
Epoch: [11][290/unk]	Per Sample Total Time 0.00444	Per Sample Data Time 0.00131	Per Sample DNN Time 0.00313	Train Loss 0.1783	
Epoch: [11][390/unk]	Per Sample Total Time 0.00439	Per Sample Data Time 0.00126	Per Sample DNN Time 0.00313	Train Loss 0.1765	
start validation
acc: 0.572600
AUC: 0.919337
Avg Precision: 0.430325
Avg Recall: 0.795800
d_prime: 1.980787
train_loss: 0.176456
valid_loss: 0.702602
validation finished
Epoch-11 lr: 8.014427207031248e-05
epoch 11 training time: 252.860
---------------
2023-06-07 17:14:45.650641
current #epochs=12, #steps=4301




Epoch: [12][99/unk]	Per Sample Total Time 0.00472	Per Sample Data Time 0.00158	Per Sample DNN Time 0.00314	Train Loss 0.1746	
Epoch: [12][199/unk]	Per Sample Total Time 0.00451	Per Sample Data Time 0.00138	Per Sample DNN Time 0.00313	Train Loss 0.1726	
Epoch: [12][299/unk]	Per Sample Total Time 0.00445	Per Sample Data Time 0.00132	Per Sample DNN Time 0.00313	Train Loss 0.1719	
start validation
acc: 0.583500
AUC: 0.921336
Avg Precision: 0.375038
Avg Recall: 0.863600
d_prime: 1.999857
train_loss: 0.170303
valid_loss: 0.701608
validation finished
Epoch-12 lr: 6.81226312597656e-05
epoch 12 training time: 253.111
---------------
2023-06-07 17:18:58.761469
current #epochs=13, #steps=4692




Epoch: [13][8/unk]	Per Sample Total Time 0.00637	Per Sample Data Time 0.00324	Per Sample DNN Time 0.00314	Train Loss 0.1723	
Epoch: [13][108/unk]	Per Sample Total Time 0.00453	Per Sample Data Time 0.00138	Per Sample DNN Time 0.00315	Train Loss 0.1692	
Epoch: [13][208/unk]	Per Sample Total Time 0.00443	Per Sample Data Time 0.00129	Per Sample DNN Time 0.00314	Train Loss 0.1670	
Epoch: [13][308/unk]	Per Sample Total Time 0.00442	Per Sample Data Time 0.00128	Per Sample DNN Time 0.00314	Train Loss 0.1662	
start validation
acc: 0.587700
AUC: 0.922805
Avg Precision: 0.386108
Avg Recall: 0.839200
d_prime: 2.014120
train_loss: 0.165159
valid_loss: 0.701781
validation finished
Epoch-13 lr: 5.7904236570800764e-05
epoch 13 training time: 252.616
---------------
2023-06-07 17:23:11.376896
current #epochs=14, #steps=5083




Epoch: [14][17/unk]	Per Sample Total Time 0.00532	Per Sample Data Time 0.00217	Per Sample DNN Time 0.00315	Train Loss 0.1647	
Epoch: [14][117/unk]	Per Sample Total Time 0.00456	Per Sample Data Time 0.00143	Per Sample DNN Time 0.00314	Train Loss 0.1634	
Epoch: [14][217/unk]	Per Sample Total Time 0.00442	Per Sample Data Time 0.00129	Per Sample DNN Time 0.00314	Train Loss 0.1613	
Epoch: [14][317/unk]	Per Sample Total Time 0.00444	Per Sample Data Time 0.00130	Per Sample DNN Time 0.00314	Train Loss 0.1602	
start validation
acc: 0.602000
AUC: 0.927483
Avg Precision: 0.428039
Avg Recall: 0.831000
d_prime: 2.060927
train_loss: 0.159385
valid_loss: 0.700826
validation finished
Epoch-14 lr: 4.921860108518065e-05
epoch 14 training time: 252.269
---------------
2023-06-07 17:27:23.645861
current #epochs=15, #steps=5474




Epoch: [15][26/unk]	Per Sample Total Time 0.00557	Per Sample Data Time 0.00242	Per Sample DNN Time 0.00314	Train Loss 0.1601	
Epoch: [15][126/unk]	Per Sample Total Time 0.00465	Per Sample Data Time 0.00152	Per Sample DNN Time 0.00314	Train Loss 0.1563	
Epoch: [15][226/unk]	Per Sample Total Time 0.00454	Per Sample Data Time 0.00140	Per Sample DNN Time 0.00314	Train Loss 0.1551	
Epoch: [15][326/unk]	Per Sample Total Time 0.00452	Per Sample Data Time 0.00139	Per Sample DNN Time 0.00314	Train Loss 0.1542	
start validation
acc: 0.611400
AUC: 0.929837
Avg Precision: 0.407603
Avg Recall: 0.846500
d_prime: 2.085365
train_loss: 0.153700
valid_loss: 0.699597
validation finished
Epoch-15 lr: 4.183581092240355e-05
epoch 15 training time: 255.976
---------------
2023-06-07 17:31:39.621716
current #epochs=16, #steps=5865




Epoch: [16][35/unk]	Per Sample Total Time 0.00474	Per Sample Data Time 0.00160	Per Sample DNN Time 0.00314	Train Loss 0.1541	
Epoch: [16][135/unk]	Per Sample Total Time 0.00450	Per Sample Data Time 0.00136	Per Sample DNN Time 0.00314	Train Loss 0.1515	
Epoch: [16][235/unk]	Per Sample Total Time 0.00452	Per Sample Data Time 0.00138	Per Sample DNN Time 0.00314	Train Loss 0.1507	
Epoch: [16][335/unk]	Per Sample Total Time 0.00446	Per Sample Data Time 0.00132	Per Sample DNN Time 0.00314	Train Loss 0.1491	
start validation
acc: 0.611400
AUC: 0.931446
Avg Precision: 0.415763
Avg Recall: 0.849600
d_prime: 2.102432
train_loss: 0.148882
valid_loss: 0.699060
validation finished
Epoch-16 lr: 3.556043928404302e-05
epoch 16 training time: 252.200
---------------
2023-06-07 17:35:51.821905
current #epochs=17, #steps=6256




Epoch: [17][44/unk]	Per Sample Total Time 0.00449	Per Sample Data Time 0.00135	Per Sample DNN Time 0.00313	Train Loss 0.1485	
Epoch: [17][144/unk]	Per Sample Total Time 0.00448	Per Sample Data Time 0.00133	Per Sample DNN Time 0.00315	Train Loss 0.1466	
Epoch: [17][244/unk]	Per Sample Total Time 0.00447	Per Sample Data Time 0.00133	Per Sample DNN Time 0.00315	Train Loss 0.1464	
Epoch: [17][344/unk]	Per Sample Total Time 0.00442	Per Sample Data Time 0.00128	Per Sample DNN Time 0.00314	Train Loss 0.1445	
start validation
acc: 0.613400
AUC: 0.931907
Avg Precision: 0.422195
Avg Recall: 0.845100
d_prime: 2.107384
train_loss: 0.144416
valid_loss: 0.698318
validation finished
Epoch-17 lr: 3.0226373391436563e-05
epoch 17 training time: 252.001
---------------
2023-06-07 17:40:03.823230
current #epochs=18, #steps=6647




Epoch: [18][53/unk]	Per Sample Total Time 0.00446	Per Sample Data Time 0.00129	Per Sample DNN Time 0.00316	Train Loss 0.1453	
Epoch: [18][153/unk]	Per Sample Total Time 0.00447	Per Sample Data Time 0.00132	Per Sample DNN Time 0.00315	Train Loss 0.1426	
Epoch: [18][253/unk]	Per Sample Total Time 0.00445	Per Sample Data Time 0.00131	Per Sample DNN Time 0.00314	Train Loss 0.1425	
Epoch: [18][353/unk]	Per Sample Total Time 0.00442	Per Sample Data Time 0.00127	Per Sample DNN Time 0.00315	Train Loss 0.1405	
start validation
acc: 0.616500
AUC: 0.932152
Avg Precision: 0.393802
Avg Recall: 0.847400
d_prime: 2.110023
train_loss: 0.140430
valid_loss: 0.697574
validation finished
Epoch-18 lr: 2.5692417382721078e-05
epoch 18 training time: 251.491
---------------
2023-06-07 17:44:15.314781
current #epochs=19, #steps=7038




Epoch: [19][62/unk]	Per Sample Total Time 0.00465	Per Sample Data Time 0.00151	Per Sample DNN Time 0.00314	Train Loss 0.1430	
Epoch: [19][162/unk]	Per Sample Total Time 0.00460	Per Sample Data Time 0.00146	Per Sample DNN Time 0.00313	Train Loss 0.1391	
Epoch: [19][262/unk]	Per Sample Total Time 0.00449	Per Sample Data Time 0.00136	Per Sample DNN Time 0.00314	Train Loss 0.1388	
Epoch: [19][362/unk]	Per Sample Total Time 0.00449	Per Sample Data Time 0.00135	Per Sample DNN Time 0.00314	Train Loss 0.1366	
start validation
acc: 0.618300
AUC: 0.932451
Avg Precision: 0.433254
Avg Recall: 0.831000
d_prime: 2.113254
train_loss: 0.136793
valid_loss: 0.697045
validation finished
Epoch-19 lr: 2.1838554775312915e-05
epoch 19 training time: 256.937
---------------
2023-06-07 17:48:32.251568
current #epochs=20, #steps=7429




Epoch: [20][71/unk]	Per Sample Total Time 0.00448	Per Sample Data Time 0.00135	Per Sample DNN Time 0.00314	Train Loss 0.1397	
Epoch: [20][171/unk]	Per Sample Total Time 0.00479	Per Sample Data Time 0.00165	Per Sample DNN Time 0.00314	Train Loss 0.1356	
Epoch: [20][271/unk]	Per Sample Total Time 0.00463	Per Sample Data Time 0.00149	Per Sample DNN Time 0.00314	Train Loss 0.1353	
Epoch: [20][371/unk]	Per Sample Total Time 0.00457	Per Sample Data Time 0.00143	Per Sample DNN Time 0.00314	Train Loss 0.1332	
start validation
acc: 0.616100
AUC: 0.932713
Avg Precision: 0.432903
Avg Recall: 0.830800
d_prime: 2.116094
train_loss: 0.133415
valid_loss: 0.696631
validation finished
Epoch-20 lr: 1.8562771559015977e-05
epoch 20 training time: 259.730
---------------
2023-06-07 17:52:51.981886
current #epochs=21, #steps=7820




Epoch: [21][80/unk]	Per Sample Total Time 0.00471	Per Sample Data Time 0.00155	Per Sample DNN Time 0.00316	Train Loss 0.1370	
Epoch: [21][180/unk]	Per Sample Total Time 0.00453	Per Sample Data Time 0.00138	Per Sample DNN Time 0.00315	Train Loss 0.1324	
Epoch: [21][280/unk]	Per Sample Total Time 0.00446	Per Sample Data Time 0.00132	Per Sample DNN Time 0.00314	Train Loss 0.1321	
Epoch: [21][380/unk]	Per Sample Total Time 0.00443	Per Sample Data Time 0.00128	Per Sample DNN Time 0.00314	Train Loss 0.1306	
start validation
acc: 0.620900
AUC: 0.933007
Avg Precision: 0.429911
Avg Recall: 0.836000
d_prime: 2.119291
train_loss: 0.130554
valid_loss: 0.696404
validation finished
Epoch-21 lr: 1.577835582516358e-05
epoch 21 training time: 261.488
---------------
2023-06-07 17:57:13.469795
current #epochs=22, #steps=8211




Epoch: [22][89/unk]	Per Sample Total Time 0.00474	Per Sample Data Time 0.00160	Per Sample DNN Time 0.00314	Train Loss 0.1337	
Epoch: [22][189/unk]	Per Sample Total Time 0.00452	Per Sample Data Time 0.00138	Per Sample DNN Time 0.00314	Train Loss 0.1294	
Epoch: [22][289/unk]	Per Sample Total Time 0.00449	Per Sample Data Time 0.00135	Per Sample DNN Time 0.00314	Train Loss 0.1292	
Epoch: [22][389/unk]	Per Sample Total Time 0.00442	Per Sample Data Time 0.00129	Per Sample DNN Time 0.00314	Train Loss 0.1278	
start validation
acc: 0.625900
AUC: 0.933636
Avg Precision: 0.405979
Avg Recall: 0.842000
d_prime: 2.126170
train_loss: 0.127798
valid_loss: 0.696135
validation finished
Epoch-22 lr: 1.3411602451389044e-05
epoch 22 training time: 258.831
---------------
2023-06-07 18:01:32.301134
current #epochs=23, #steps=8602




Epoch: [23][98/unk]	Per Sample Total Time 0.00460	Per Sample Data Time 0.00146	Per Sample DNN Time 0.00314	Train Loss 0.1311	
Epoch: [23][198/unk]	Per Sample Total Time 0.00447	Per Sample Data Time 0.00133	Per Sample DNN Time 0.00314	Train Loss 0.1271	
Epoch: [23][298/unk]	Per Sample Total Time 0.00445	Per Sample Data Time 0.00132	Per Sample DNN Time 0.00314	Train Loss 0.1269	
start validation
acc: 0.627300
AUC: 0.934417
Avg Precision: 0.325709
Avg Recall: 0.886900
d_prime: 2.134783
train_loss: 0.125381
valid_loss: 0.695727
validation finished
Epoch-23 lr: 1.1399862083680687e-05
epoch 23 training time: 256.039
---------------
2023-06-07 18:05:48.340356
current #epochs=24, #steps=8993




Epoch: [24][7/unk]	Per Sample Total Time 0.00672	Per Sample Data Time 0.00358	Per Sample DNN Time 0.00314	Train Loss 0.1329	
Epoch: [24][107/unk]	Per Sample Total Time 0.00472	Per Sample Data Time 0.00158	Per Sample DNN Time 0.00314	Train Loss 0.1278	
Epoch: [24][207/unk]	Per Sample Total Time 0.00453	Per Sample Data Time 0.00139	Per Sample DNN Time 0.00314	Train Loss 0.1252	
Epoch: [24][307/unk]	Per Sample Total Time 0.00453	Per Sample Data Time 0.00139	Per Sample DNN Time 0.00314	Train Loss 0.1245	
start validation
acc: 0.626400
AUC: 0.934417
Avg Precision: 0.329538
Avg Recall: 0.889500
d_prime: 2.134784
train_loss: 0.123200
valid_loss: 0.695838
validation finished
Epoch-24 lr: 9.689882771128584e-06
epoch 24 training time: 257.182
---------------
2023-06-07 18:10:05.522208
current #epochs=25, #steps=9384




Epoch: [25][16/unk]	Per Sample Total Time 0.00543	Per Sample Data Time 0.00230	Per Sample DNN Time 0.00313	Train Loss 0.1302	
Epoch: [25][116/unk]	Per Sample Total Time 0.00452	Per Sample Data Time 0.00139	Per Sample DNN Time 0.00314	Train Loss 0.1254	
Epoch: [25][216/unk]	Per Sample Total Time 0.00450	Per Sample Data Time 0.00137	Per Sample DNN Time 0.00314	Train Loss 0.1230	
Epoch: [25][316/unk]	Per Sample Total Time 0.00452	Per Sample Data Time 0.00138	Per Sample DNN Time 0.00314	Train Loss 0.1221	
start validation
acc: 0.628900
AUC: 0.934704
Avg Precision: 0.331861
Avg Recall: 0.892500
d_prime: 2.137972
train_loss: 0.121254
valid_loss: 0.695843
validation finished
Epoch-25 lr: 8.236400355459297e-06
epoch 25 training time: 257.098
---------------
2023-06-07 18:14:22.620442
current #epochs=26, #steps=9775




Epoch: [26][25/unk]	Per Sample Total Time 0.00497	Per Sample Data Time 0.00183	Per Sample DNN Time 0.00314	Train Loss 0.1281	
Epoch: [26][125/unk]	Per Sample Total Time 0.00451	Per Sample Data Time 0.00137	Per Sample DNN Time 0.00315	Train Loss 0.1234	
Epoch: [26][225/unk]	Per Sample Total Time 0.00439	Per Sample Data Time 0.00125	Per Sample DNN Time 0.00315	Train Loss 0.1213	
Epoch: [26][325/unk]	Per Sample Total Time 0.00446	Per Sample Data Time 0.00131	Per Sample DNN Time 0.00315	Train Loss 0.1200	
start validation
acc: 0.630600
AUC: 0.934921
Avg Precision: 0.336842
Avg Recall: 0.889100
d_prime: 2.140382
train_loss: 0.119476
valid_loss: 0.695642
validation finished
Epoch-26 lr: 7.0009403021404025e-06
epoch 26 training time: 257.095
---------------
2023-06-07 18:18:39.715247
current #epochs=27, #steps=10166




Epoch: [27][34/unk]	Per Sample Total Time 0.00499	Per Sample Data Time 0.00185	Per Sample DNN Time 0.00314	Train Loss 0.1246	
Epoch: [27][134/unk]	Per Sample Total Time 0.00456	Per Sample Data Time 0.00142	Per Sample DNN Time 0.00314	Train Loss 0.1212	
Epoch: [27][234/unk]	Per Sample Total Time 0.00458	Per Sample Data Time 0.00144	Per Sample DNN Time 0.00314	Train Loss 0.1195	
Epoch: [27][334/unk]	Per Sample Total Time 0.00450	Per Sample Data Time 0.00137	Per Sample DNN Time 0.00314	Train Loss 0.1180	
start validation
acc: 0.631600
AUC: 0.935198
Avg Precision: 0.310545
Avg Recall: 0.901700
d_prime: 2.143469
train_loss: 0.117889
valid_loss: 0.695223
validation finished
Epoch-27 lr: 5.950799256819342e-06
epoch 27 training time: 258.854
---------------
2023-06-07 18:22:58.569181
current #epochs=28, #steps=10557




Epoch: [28][43/unk]	Per Sample Total Time 0.00465	Per Sample Data Time 0.00150	Per Sample DNN Time 0.00314	Train Loss 0.1212	
Epoch: [28][143/unk]	Per Sample Total Time 0.00452	Per Sample Data Time 0.00138	Per Sample DNN Time 0.00314	Train Loss 0.1194	
Epoch: [28][243/unk]	Per Sample Total Time 0.00454	Per Sample Data Time 0.00140	Per Sample DNN Time 0.00315	Train Loss 0.1183	
Epoch: [28][343/unk]	Per Sample Total Time 0.00450	Per Sample Data Time 0.00136	Per Sample DNN Time 0.00314	Train Loss 0.1165	
start validation
acc: 0.635800
AUC: 0.935486
Avg Precision: 0.313579
Avg Recall: 0.901900
d_prime: 2.146698
train_loss: 0.116434
valid_loss: 0.694807
validation finished
Epoch-28 lr: 5.058179368296441e-06
epoch 28 training time: 259.069
---------------
2023-06-07 18:27:17.638376
current #epochs=29, #steps=10948




Epoch: [29][52/unk]	Per Sample Total Time 0.00456	Per Sample Data Time 0.00142	Per Sample DNN Time 0.00314	Train Loss 0.1210	
Epoch: [29][152/unk]	Per Sample Total Time 0.00451	Per Sample Data Time 0.00137	Per Sample DNN Time 0.00314	Train Loss 0.1177	
Epoch: [29][252/unk]	Per Sample Total Time 0.00450	Per Sample Data Time 0.00136	Per Sample DNN Time 0.00314	Train Loss 0.1169	
Epoch: [29][352/unk]	Per Sample Total Time 0.00448	Per Sample Data Time 0.00134	Per Sample DNN Time 0.00314	Train Loss 0.1150	
start validation
acc: 0.639100
AUC: 0.935497
Avg Precision: 0.319081
Avg Recall: 0.896000
d_prime: 2.146828
train_loss: 0.115047
valid_loss: 0.694355
validation finished
Epoch-29 lr: 4.299452463051975e-06
epoch 29 training time: 257.276
---------------
2023-06-07 18:31:34.914631
current #epochs=30, #steps=11339




Epoch: [30][61/unk]	Per Sample Total Time 0.00448	Per Sample Data Time 0.00134	Per Sample DNN Time 0.00314	Train Loss 0.1213	
Epoch: [30][161/unk]	Per Sample Total Time 0.00464	Per Sample Data Time 0.00149	Per Sample DNN Time 0.00315	Train Loss 0.1167	
Epoch: [30][261/unk]	Per Sample Total Time 0.00456	Per Sample Data Time 0.00141	Per Sample DNN Time 0.00314	Train Loss 0.1159	
Epoch: [30][361/unk]	Per Sample Total Time 0.00452	Per Sample Data Time 0.00138	Per Sample DNN Time 0.00314	Train Loss 0.1137	
start validation
acc: 0.638500
AUC: 0.935343
Avg Precision: 0.322927
Avg Recall: 0.891900
d_prime: 2.145092
train_loss: 0.113858
valid_loss: 0.694129
validation finished
Epoch-30 lr: 3.6545345935941787e-06
epoch 30 training time: 259.152


## Audio model initialised from the CIFAR10-pretrained QViT

In [21]:
class Arguments:
  """Settings for fine-tuning the audio model that reuses the CIFAR10-pretrained blocks."""

  # --- model / dataset selection ---
  model = 'ast'
  dataset = 'speechcommands'
  imagenetpretrain = True
  audiosetpretrain = False
  n_class = 35

  # --- input features and augmentation ---
  audio_length = 128
  fstride = 10
  tstride = 10
  freqm = 48                # frequency-mask size passed to the train dataloader
  timem = 48                # time-mask size passed to the train dataloader
  mixup = 0.6               # mix-up rate passed to the train dataloader
  noise = True
  dataset_mean = -6.845978  # normalisation statistics handed to the dataloaders
  dataset_std = 5.5654526
  bal = None

  # --- optimisation ---
  optimizer = 'adam'
  loss = 'BCE'
  metrics = 'acc'
  lr = 2.5e-4
  lr_patience = 2
  n_epochs = 20
  batch_size = 32
  num_workers = 32
  warmup = False

  # --- learning-rate schedule ---
  lrscheduler_start = 5
  lrscheduler_step = 1
  lrscheduler_decay = 0.85

  # --- "wa" options (consumed by the training loop, which is not defined in this notebook) ---
  wa = False
  wa_start = 1
  wa_end = 5

  # --- logging / checkpoints ---
  n_print_steps = 100
  save_model = True
  exp_dir = '/content/drive/MyDrive/Thesis/resout_qast_with_cifar10'


args_usepretrain = Arguments()

In [22]:
# JSON manifests of the Speech Commands train / validation / evaluation splits.
data_train = '/content/drive/MyDrive/Thesis/datafiles/speechcommand_train_data.json'
data_val_path = '/content/drive/MyDrive/Thesis/datafiles/speechcommand_valid_data.json'
data_eval_path = '/content/drive/MyDrive/Thesis/datafiles/speechcommand_eval_data.json'

# CSV that maps class indices to label names.
label_csv = '/content/drive/MyDrive/Thesis/ast/egs/speechcommands/data/speechcommands_class_labels_indices.csv'

# Every setting is read from args_usepretrain. The previous version mixed in values from
# the unrelated global `args`, so this cell silently depended on another experiment's state.
audio_conf = {'num_mel_bins': 128, 'target_length': args_usepretrain.audio_length,
              'freqm': args_usepretrain.freqm, 'timem': args_usepretrain.timem,
              'mixup': args_usepretrain.mixup, 'dataset': args_usepretrain.dataset, 'mode': 'train',
              'mean': args_usepretrain.dataset_mean, 'std': args_usepretrain.dataset_std,
              'noise': args_usepretrain.noise}

# Validation and evaluation use no masking, no mix-up and no noise augmentation.
val_audio_conf = {'num_mel_bins': 128, 'target_length': args_usepretrain.audio_length,
                  'freqm': 0, 'timem': 0, 'mixup': 0, 'dataset': args_usepretrain.dataset,
                  'mode': 'validation', 'mean': args_usepretrain.dataset_mean,
                  'std': args_usepretrain.dataset_std, 'noise': False}
eval_audio_conf = {'num_mel_bins': 128, 'target_length': args_usepretrain.audio_length,
                   'freqm': 0, 'timem': 0, 'mixup': 0, 'dataset': args_usepretrain.dataset,
                   'mode': 'evaluation', 'mean': args_usepretrain.dataset_mean,
                   'std': args_usepretrain.dataset_std, 'noise': False}

train_loader = torch.utils.data.DataLoader(
            dataloader.AudiosetDataset(data_train, label_csv=label_csv, audio_conf=audio_conf),
            batch_size=args_usepretrain.batch_size, shuffle=True, num_workers=args_usepretrain.num_workers, pin_memory=True)

# Fix: the evaluation split is now loaded with eval_audio_conf (it used val_audio_conf,
# which left eval_audio_conf defined but never used).
eval_loader = torch.utils.data.DataLoader(
        dataloader.AudiosetDataset(data_eval_path, label_csv=label_csv, audio_conf=eval_audio_conf),
        batch_size=args_usepretrain.batch_size*2, shuffle=False, num_workers=args_usepretrain.num_workers, pin_memory=True)

val_loader = torch.utils.data.DataLoader(
        dataloader.AudiosetDataset(data_val_path, label_csv=label_csv, audio_conf=val_audio_conf),
        batch_size=args_usepretrain.batch_size*2, shuffle=False, num_workers=args_usepretrain.num_workers, pin_memory=True)

---------------the train dataloader---------------
now using following mask: 48 freq, 48 time
now using mix-up with rate 0.600000
now process speechcommands
use dataset mean -6.846 and std 5.565 to normalize the input.
now use noise augmentation
number of classes is 35




---------------the validation dataloader---------------
now using following mask: 0 freq, 0 time
now using mix-up with rate 0.000000
now process speechcommands
use dataset mean -6.846 and std 5.565 to normalize the input.
number of classes is 35
---------------the validation dataloader---------------
now using following mask: 0 freq, 0 time
now using mix-up with rate 0.000000
now process speechcommands
use dataset mean -6.846 and std 5.565 to normalize the input.
number of classes is 35


In [23]:
import torch
# Path of the checkpoint that is read with torch.load in a later cell; judging by the
# directory name it is the best model saved by the CIFAR10 pre-training run.
best_pretrained_model = '/content/drive/MyDrive/Thesis/pretrain_qvit_CIFAR10/models/best_audio_model.pth'

In [25]:
from collections import OrderedDict

# The saved checkpoint has keys prefixed with "module." (e.g. "module.fc1").
# To load it into a bare model the prefix has to be removed.
#
# Only a literal leading "module." is stripped. The previous version dropped the first
# dotted component unconditionally, which would corrupt keys of a checkpoint saved
# without that prefix (e.g. "blocks.0.attn..." -> "0.attn...").
#
# map_location="cpu" lets the checkpoint be read on a machine without the device it was
# saved from; load_state_dict copies the values into the model's own parameters anyway.
state_dict = torch.load(best_pretrained_model, map_location="cpu")
new_state_dict = OrderedDict(
    (key[len("module."):] if key.startswith("module.") else key, value)
    for key, value in state_dict.items()
)

In [35]:
# Rebuild the QVIT architecture with the pre-training configuration (class count and
# image size come from args_pretraining, defined in an earlier cell).
ast_CIFAR_pretrained = QVIT(label_dim=args_pretraining.n_classes, img_size = args_pretraining.img_size)
# Restore the weights from new_state_dict. This is kept as the last expression so the
# notebook displays the key-matching report returned by load_state_dict.
ast_CIFAR_pretrained.load_state_dict(new_state_dict)

<All keys matched successfully>

In [36]:
class ASTModel_with_pretraining(nn.Module):
  """Audio transformer whose encoder blocks can be initialised from `ast_CIFAR_pretrained`.

  The model embeds the input with `PatchEmbed`, prepends a class token and a distillation
  token, adds a learned positional embedding, runs the transformer blocks and classifies
  the average of the two token outputs.

  Args:
    label_dim: number of output classes; a value <= 0 replaces the head by `nn.Identity`.
    fstride, tstride: patch strides, used here to compute the number of patches.
    input_fdim, input_tdim: size of the input along the two spectrogram axes.
    imagenet_pretrain: if True, the blocks are initialised from the globally defined
      `ast_CIFAR_pretrained` model (despite the parameter name); otherwise fresh blocks
      are created.
    audioset_pretrain, model_size, verbose: accepted for signature compatibility;
      not used by this implementation.
  """

  def __init__(self, label_dim=527, fstride=10, tstride=10, input_fdim=128, input_tdim=1024, imagenet_pretrain=True, audioset_pretrain=False, model_size='base384', verbose=True):
    super(ASTModel_with_pretraining, self).__init__()
    # Local import keeps this cell self-contained.
    import copy

    self.original_embedding_dim = 768
    num_heads = 12
    mlp_ratio = 4.
    qkv_bias = True
    qk_norm = False
    drop_rate = 0.
    attn_drop_rate = 0.
    depth = 12

    # Automatically get the intermediate shape (number of patches per axis).
    f_dim, t_dim = self.get_shape(fstride, tstride, input_fdim, input_tdim)
    num_patches = f_dim * t_dim

    # NOTE(review): PatchEmbed is built with its defaults, so fstride/tstride only affect
    # num_patches here — confirm PatchEmbed's own stride matches the values passed in.
    self.patch_embed = PatchEmbed()
    self.cls_token = nn.Parameter(torch.zeros(1, 1, self.original_embedding_dim))
    self.dist_token = nn.Parameter(torch.zeros(1, 1, self.original_embedding_dim))
    # +2 positions for the class token and the distillation token.
    self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 2, self.original_embedding_dim))
    trunc_normal_(self.pos_embed, std=.02)
    self.pos_drop = nn.Dropout(p=0.)

    # TODO pretrained or sinusoidal
    if imagenet_pretrain:
      # Deep-copy instead of aliasing: sharing the module object meant that moving this
      # model to another device or fine-tuning it also changed `ast_CIFAR_pretrained`,
      # and that every instance built this way shared the same weights.
      self.blocks = copy.deepcopy(ast_CIFAR_pretrained.blocks)
    else:
      self.blocks = nn.Sequential(*[
              Block(
                  self.original_embedding_dim,
                  num_heads,
                  mlp_ratio = mlp_ratio,
                  qkv_bias = qkv_bias,
                  qk_norm = qk_norm,
                  drop = drop_rate,
                  attn_drop = attn_drop_rate,
              )
              for i in range(depth)])
    self.norm =  Norm(self.original_embedding_dim)

    # Classifier Head
    self.fc_norm = Norm(self.original_embedding_dim)
    self.head = nn.Linear(self.original_embedding_dim, label_dim) if label_dim > 0 else nn.Identity()

  @autocast()
  def forward(self, x):
    """Run the model on a batch and return the head output.

    Args:
      x: 3-D batch; a channel axis is inserted at dim 1 and the last two axes are
        swapped before patch embedding (presumably x is (batch, time, freq) — TODO confirm).

    Returns:
      Tensor with one row per batch element: `label_dim` columns when a linear head is
      used, otherwise the embedding itself.
    """
    x = x.unsqueeze(1)
    x = x.transpose(2, 3)
    B = x.shape[0]

    x = self.patch_embed(x)

    # Broadcast the single learned tokens over the batch.
    cls_tokens = self.cls_token.expand(B, -1, -1)
    dist_token = self.dist_token.expand(B, -1, -1)

    x = torch.cat((cls_tokens, dist_token, x), dim=1)
    x = x + self.pos_embed
    x = self.pos_drop(x)

    x = self.blocks(x)
    x = self.norm(x)

    # Average the class-token and distillation-token outputs.
    x = (x[:, 0] + x[:, 1]) / 2
    x = self.fc_norm(x)
    x = self.head(x)
    return x


  def get_shape(self, fstride, tstride, input_fdim=128, input_tdim=1024):
    """Return the patch-grid size (f_dim, t_dim) for the given strides and input size.

    The size is measured by pushing a dummy input through a throw-away 16x16
    convolution with the requested strides and reading the output shape.
    """
    test_input = torch.randn(1, 1, input_fdim, input_tdim)
    test_proj = nn.Conv2d(1, self.original_embedding_dim, kernel_size=(16, 16), stride=(fstride, tstride))
    test_out = test_proj(test_input)
    f_dim = test_out.shape[2]
    t_dim = test_out.shape[3]
    return f_dim, t_dim


In [37]:
# Build the fine-tuning model. Every hyper-parameter comes from args_usepretrain;
# label_dim previously read n_class from the unrelated global `args`, which tied this
# cell to another experiment's configuration object.
new_audio_model = ASTModel_with_pretraining(label_dim=args_usepretrain.n_class, fstride=args_usepretrain.fstride, tstride=args_usepretrain.tstride, input_fdim=128,
                                input_tdim=args_usepretrain.audio_length, imagenet_pretrain=True,
                                audioset_pretrain=args_usepretrain.audiosetpretrain, model_size='base384')
# Kept as the last expression so the notebook displays the module tree.
new_audio_model.to(device)

ASTModel_with_pretraining(
  (patch_embed): PatchEmbed(
    (projq): QuaternionConv(in_channels=1, out_channels=192, bias=True, kernel_size=(16, 16), stride=10, padding=0, init_criterion=glorot, weight_init=quaternion, seed=808, rotation=False, q_format=True, operation=convolution2d)
  )
  (pos_drop): Dropout(p=0.0, inplace=False)
  (blocks): Sequential(
    (0): Block(
      (norm1): Norm()
      (attn): Attention(
        (qkv): QuaternionLinearAutograd(in_features=192, out_features=576, bias=True, init_criterion=glorot, weight_init=quaternion, rotation=False, seed=724)
        (q_norm): Identity()
        (k_norm): Identity()
        (attn_drop): Dropout(p=0.0, inplace=False)
        (proj): QuaternionLinearAutograd(in_features=192, out_features=192, bias=True, init_criterion=glorot, weight_init=quaternion, rotation=False, seed=617)
        (proj_drop): Dropout(p=0.1, inplace=False)
      )
      (norm2): Norm()
      (mlp): Mlp(
        (fc1): QuaternionLinearAutograd(in_features=1

In [38]:
# prepare_result_saving is defined elsewhere in the notebook; judging by its output it
# sets up the result folders and a Target.csv for this experiment — TODO confirm.
prepare_result_saving(args_usepretrain)

Folders already exists
Target.csv already exists


In [None]:
# Announce the run length, then fine-tune on the Speech Commands loaders.
print(f'Now starting training for {args_usepretrain.n_epochs:d} epochs')

train(
    new_audio_model,
    train_loader,
    val_loader,
    args_usepretrain,
)

Now starting training for 20 epochs
running on cuda
Total parameter number is : 21.665 million
Total trainable parameter number is : 21.665 million
now training with speechcommands, main metrics: acc, loss function: BCEWithLogitsLoss(), learning rate scheduler: <torch.optim.lr_scheduler.MultiStepLR object at 0x7f77935561d0>
The learning rate scheduler starts at 5 epoch with decay rate of 0.850 every 1 epochs
current #steps=0, #epochs=1
start training...
---------------
2023-06-12 12:48:19.350827
current #epochs=1, #steps=0
Epoch: [1][100/100000]	Per Sample Total Time 0.06609	Per Sample Data Time 0.03517	Per Sample DNN Time 0.03092	Train Loss 0.1747	
Epoch: [1][200/100000]	Per Sample Total Time 0.04741	Per Sample Data Time 0.01768	Per Sample DNN Time 0.02973	Train Loss 0.1518	
Epoch: [1][300/100000]	Per Sample Total Time 0.04114	Per Sample Data Time 0.01181	Per Sample DNN Time 0.02933	Train Loss 0.1441	
Epoch: [1][400/100000]	Per Sample Total Time 0.03800	Per Sample Data Time 0.00887	Pe



Epoch: [2][52/100000]	Per Sample Total Time 0.03078	Per Sample Data Time 0.00198	Per Sample DNN Time 0.02880	Train Loss 0.1288	
Epoch: [2][152/100000]	Per Sample Total Time 0.02930	Per Sample Data Time 0.00070	Per Sample DNN Time 0.02861	Train Loss 0.1287	
Epoch: [2][252/100000]	Per Sample Total Time 0.02898	Per Sample Data Time 0.00042	Per Sample DNN Time 0.02855	Train Loss 0.1285	
Epoch: [2][352/100000]	Per Sample Total Time 0.02884	Per Sample Data Time 0.00031	Per Sample DNN Time 0.02853	Train Loss 0.1285	
Epoch: [2][452/100000]	Per Sample Total Time 0.02876	Per Sample Data Time 0.00024	Per Sample DNN Time 0.02852	Train Loss 0.1285	
Epoch: [2][552/100000]	Per Sample Total Time 0.02872	Per Sample Data Time 0.00020	Per Sample DNN Time 0.02852	Train Loss 0.1285	
Epoch: [2][652/100000]	Per Sample Total Time 0.02869	Per Sample Data Time 0.00017	Per Sample DNN Time 0.02851	Train Loss 0.1285	
Epoch: [2][752/100000]	Per Sample Total Time 0.02866	Per Sample Data Time 0.00015	Per Sample DNN T



Epoch: [3][4/100000]	Per Sample Total Time 0.05042	Per Sample Data Time 0.02123	Per Sample DNN Time 0.02919	Train Loss 0.1274	
Epoch: [3][104/100000]	Per Sample Total Time 0.02968	Per Sample Data Time 0.00102	Per Sample DNN Time 0.02865	Train Loss 0.1282	
Epoch: [3][204/100000]	Per Sample Total Time 0.02911	Per Sample Data Time 0.00053	Per Sample DNN Time 0.02858	Train Loss 0.1282	
Epoch: [3][304/100000]	Per Sample Total Time 0.02893	Per Sample Data Time 0.00036	Per Sample DNN Time 0.02857	Train Loss 0.1283	
Epoch: [3][404/100000]	Per Sample Total Time 0.02883	Per Sample Data Time 0.00027	Per Sample DNN Time 0.02856	Train Loss 0.1281	
Epoch: [3][504/100000]	Per Sample Total Time 0.02877	Per Sample Data Time 0.00022	Per Sample DNN Time 0.02855	Train Loss 0.1281	
Epoch: [3][604/100000]	Per Sample Total Time 0.02873	Per Sample Data Time 0.00019	Per Sample DNN Time 0.02855	Train Loss 0.1281	
Epoch: [3][704/100000]	Per Sample Total Time 0.02871	Per Sample Data Time 0.00016	Per Sample DNN Ti



Epoch: [4][56/100000]	Per Sample Total Time 0.03084	Per Sample Data Time 0.00210	Per Sample DNN Time 0.02874	Train Loss 0.1258	
Epoch: [4][156/100000]	Per Sample Total Time 0.02941	Per Sample Data Time 0.00077	Per Sample DNN Time 0.02864	Train Loss 0.1253	
Epoch: [4][256/100000]	Per Sample Total Time 0.02912	Per Sample Data Time 0.00048	Per Sample DNN Time 0.02865	Train Loss 0.1250	
Epoch: [4][356/100000]	Per Sample Total Time 0.02898	Per Sample Data Time 0.00035	Per Sample DNN Time 0.02864	Train Loss 0.1247	
Epoch: [4][456/100000]	Per Sample Total Time 0.02890	Per Sample Data Time 0.00027	Per Sample DNN Time 0.02863	Train Loss 0.1245	
Epoch: [4][556/100000]	Per Sample Total Time 0.02886	Per Sample Data Time 0.00023	Per Sample DNN Time 0.02863	Train Loss 0.1243	
Epoch: [4][656/100000]	Per Sample Total Time 0.02882	Per Sample Data Time 0.00019	Per Sample DNN Time 0.02863	Train Loss 0.1242	
Epoch: [4][756/100000]	Per Sample Total Time 0.02880	Per Sample Data Time 0.00017	Per Sample DNN T



Epoch: [5][8/100000]	Per Sample Total Time 0.04104	Per Sample Data Time 0.01199	Per Sample DNN Time 0.02905	Train Loss 0.1166	
Epoch: [5][108/100000]	Per Sample Total Time 0.02972	Per Sample Data Time 0.00100	Per Sample DNN Time 0.02872	Train Loss 0.1155	
Epoch: [5][208/100000]	Per Sample Total Time 0.02925	Per Sample Data Time 0.00053	Per Sample DNN Time 0.02872	Train Loss 0.1155	
Epoch: [5][308/100000]	Per Sample Total Time 0.02906	Per Sample Data Time 0.00036	Per Sample DNN Time 0.02869	Train Loss 0.1152	
Epoch: [5][408/100000]	Per Sample Total Time 0.02896	Per Sample Data Time 0.00028	Per Sample DNN Time 0.02869	Train Loss 0.1148	
Epoch: [5][508/100000]	Per Sample Total Time 0.02891	Per Sample Data Time 0.00022	Per Sample DNN Time 0.02868	Train Loss 0.1147	
Epoch: [5][608/100000]	Per Sample Total Time 0.02887	Per Sample Data Time 0.00019	Per Sample DNN Time 0.02868	Train Loss 0.1144	
Epoch: [5][708/100000]	Per Sample Total Time 0.02885	Per Sample Data Time 0.00016	Per Sample DNN Ti

In [None]:
# Inspect the shape of the pretrained model's positional embedding.
ast_CIFAR_pretrained.pos_embed.shape

torch.Size([1, 6, 768])

## Pretraining QAST on TinyImageNet

In [None]:
# Dataset: https://huggingface.co/datasets/Maysee/tiny-imagenet
# Start from the Arguments defaults and override only what the TinyImageNet
# pre-training run needs. The learning-rate scheduler settings are not overridden.
args_pre_tiny_imagenet = Arguments()

# Data description.
args_pre_tiny_imagenet.img_size = 64
args_pre_tiny_imagenet.n_classes = 200

# Optimisation.
args_pre_tiny_imagenet.lr = 2.5e-5
args_pre_tiny_imagenet.n_epochs = 20
args_pre_tiny_imagenet.batch_size = 128
args_pre_tiny_imagenet.warmup = True
args_pre_tiny_imagenet.wa = False

# Runtime / logging.
args_pre_tiny_imagenet.num_workers = 8
args_pre_tiny_imagenet.n_print_steps = 100
args_pre_tiny_imagenet.exp_dir = '/content/drive/MyDrive/Thesis/pretrain_qvit_tiny_imagenet'

def collate_fn(examples, n_classes=None):
    """Collate dataset examples into an image batch and a one-hot label batch.

    Args:
        examples: sequence of mappings, each with an "image" entry that
            `transforms.ToTensor()` accepts and an integer "label" entry.
        n_classes: width of the one-hot target. Defaults to
            `args_pre_tiny_imagenet.n_classes`, which keeps the original behaviour
            when the function is passed to a DataLoader unchanged. For a dataset
            with a different number of classes, bind the value explicitly, e.g.
            `functools.partial(collate_fn, n_classes=1000)`.

    Returns:
        Tuple `(pixel_values, y)`: the stacked image tensors and a float tensor of
        shape `(batch, n_classes)` holding 1 at each example's label index.
    """
    if n_classes is None:
        n_classes = args_pre_tiny_imagenet.n_classes

    examples = list(examples)
    convert_tensor = transforms.ToTensor()

    pixel_values = torch.stack([convert_tensor(example["image"]) for example in examples])
    labels = torch.tensor([example["label"] for example in examples])

    # One row per example; set the column given by the label to 1.
    b_size = labels.shape[0]
    y = torch.zeros(b_size, n_classes)
    y[torch.arange(b_size), labels] = 1
    return (pixel_values, y)

In [None]:
import os

# SECURITY: a Hugging Face token used to be hard-coded on this line. A token that has
# been saved in a notebook must be treated as leaked and revoked in the account settings.
# The token is now taken from the HF_TOKEN environment variable. When it is not set the
# value True is used, which asks `datasets` to use the credentials cached by a prior
# Hugging Face login.
access_token = os.environ.get("HF_TOKEN") or True

In [None]:
from datasets import load_dataset, Image

# Stream the train and validation splits and cast the "image" column to the Image
# feature. The test split is intentionally not loaded.
dset, dset_eval = [
    load_dataset(
        'Maysee/tiny-imagenet',
        split=split_name,
        streaming=True,
        use_auth_token=access_token,
    ).cast_column("image", Image())
    for split_name in ('train', 'valid')
]


Downloading metadata:   0%|          | 0.00/3.52k [00:00<?, ?B/s]

Downloading readme:   0%|          | 0.00/3.90k [00:00<?, ?B/s]

In [None]:
# Apply the batched `transforms_` mapping to both splits (rebinding dset / dset_eval),
# then create torch-formatted views for the DataLoaders. The test split is not prepared.
dset, dset_eval = [
    split.map(transforms_, batched=True)
    for split in (dset, dset_eval)
]

dset_iter, dset_eval_iter = [
    split.with_format("torch")
    for split in (dset, dset_eval)
]

In [None]:
# Both loaders batch the streamed examples through collate_fn with the same batch size.
train_loader_tiny_imagenet = DataLoader(
    dset_iter,
    batch_size=args_pre_tiny_imagenet.batch_size,
    collate_fn=collate_fn,
    pin_memory=True,
)
eval_loader_tiny_imagenet = DataLoader(
    dset_eval_iter,
    batch_size=args_pre_tiny_imagenet.batch_size,
    collate_fn=collate_fn,
    pin_memory=True,
)
# test_loader_tiny_imagenet = DataLoader(dset_test_iter, collate_fn=collate_fn, batch_size=args_pre_tiny_imagenet.batch_size, pin_memory=True)


In [None]:
# Fresh QVIT sized for the TinyImageNet run (class count and image size from its config).
ast_tiny_imagenet = QVIT(label_dim=args_pre_tiny_imagenet.n_classes, img_size = args_pre_tiny_imagenet.img_size)

In [None]:
# Move the model to `device` (defined in an earlier cell); as the last expression it
# also makes the notebook display the module tree.
ast_tiny_imagenet.to(device)

QVIT(
  (patch_embed): PatchEmbedding(
    (proj): Conv2d(4, 768, kernel_size=(16, 16), stride=(16, 16))
    (norm): Identity()
  )
  (pos_drop): Dropout(p=0.0, inplace=False)
  (blocks): Sequential(
    (0): Block(
      (norm1): Norm()
      (attn): Attention(
        (qkv): QuaternionLinearAutograd(in_features=192, out_features=576, bias=True, init_criterion=glorot, weight_init=quaternion, rotation=False, seed=1057)
        (q_norm): Identity()
        (k_norm): Identity()
        (attn_drop): Dropout(p=0.0, inplace=False)
        (proj): QuaternionLinearAutograd(in_features=192, out_features=192, bias=True, init_criterion=glorot, weight_init=quaternion, rotation=False, seed=40)
        (proj_drop): Dropout(p=0.1, inplace=False)
      )
      (norm2): Norm()
      (mlp): Mlp(
        (fc1): QuaternionLinearAutograd(in_features=192, out_features=768, bias=True, init_criterion=glorot, weight_init=quaternion, rotation=False, seed=501)
        (drop1): Dropout(p=0.1, inplace=False)
    

In [None]:
# prepare_result_saving is defined elsewhere in the notebook; here it also receives the
# evaluation loader as val_data. Judging by its output it sets up the result folders and
# a Target.csv — TODO confirm.
prepare_result_saving(args_pre_tiny_imagenet, val_data = eval_loader_tiny_imagenet)

Folders already exists
Target.csv already exists


In [None]:
# Announce the run length, then pre-train the QVIT on the TinyImageNet loaders.
print(f'Now starting training for {args_pre_tiny_imagenet.n_epochs:d} epochs')

train(
    ast_tiny_imagenet,
    train_loader_tiny_imagenet,
    eval_loader_tiny_imagenet,
    args_pre_tiny_imagenet,
)

Now starting training for 20 epochs
running on cuda
Total parameter number is : 22.283 million
Total trainable parameter number is : 22.283 million
now training with speechcommands, main metrics: acc, loss function: BCEWithLogitsLoss(), learning rate scheduler: <torch.optim.lr_scheduler.MultiStepLR object at 0x7f4b584f30a0>
The learning rate scheduler starts at 5 epoch with decay rate of 0.850 every 1 epochs
current #steps=0, #epochs=1
start training...
---------------
2023-06-05 12:41:20.555890
current #epochs=1, #steps=0
warm-up learning rate is 0.000000
warm-up learning rate is 0.000001
warm-up learning rate is 0.000003
Epoch: [1][100/100000]	Per Sample Total Time 0.00372	Per Sample Data Time 0.00090	Per Sample DNN Time 0.00282	Train Loss 0.6722	
warm-up learning rate is 0.000004
warm-up learning rate is 0.000005
Epoch: [1][200/100000]	Per Sample Total Time 0.00328	Per Sample Data Time 0.00075	Per Sample DNN Time 0.00253	Train Loss 0.6047	
warm-up learning rate is 0.000006
warm-up l

## ImageNet

In [None]:
from datasets import load_dataset, Image

# Stream the three imagenet-1k splits (authenticated with the cached login) and cast the
# "image" column to the Image feature.
dset, dset_eval, dset_test = [
    load_dataset('imagenet-1k', split=split_name, streaming=True, use_auth_token=True).cast_column("image", Image())
    for split_name in ('train', 'validation', 'test')
]

# Apply the batched `transforms_` mapping to every split.
dset, dset_eval, dset_test = [
    split.map(transforms_, batched=True)
    for split in (dset, dset_eval, dset_test)
]

# Torch-formatted views consumed by the DataLoaders.
dset_iter, dset_eval_iter, dset_test_iter = [
    split.with_format("torch")
    for split in (dset, dset_eval, dset_test)
]

# NOTE(review): collate_fn builds its one-hot targets with
# args_pre_tiny_imagenet.n_classes (200) while the batch size comes from
# args_pretraining — confirm the class count is right for imagenet-1k before training.
train_loader_imagenet, eval_loader_imagenet, test_loader_imagenet = [
    DataLoader(split, collate_fn=collate_fn, batch_size=args_pretraining.batch_size, pin_memory=True)
    for split in (dset_iter, dset_eval_iter, dset_test_iter)
]


In [None]:
# NOTE(review): this rebinds ast_tiny_imagenet, the name of the TinyImageNet model built
# earlier, and uses args_pretraining.n_classes with img_size=32 inside the ImageNet
# section — confirm these values are intended for imagenet-1k.
ast_tiny_imagenet = QVIT(label_dim=args_pretraining.n_classes, img_size = 32)