In [47]:
import torch
import torch
from torch.utils.data import DataLoader, Dataset
import os

from optuna_utils.dataset_pytorch import ASTDataset
from optuna_utils.config import CFG
from optuna_utils.dataset import filter_data, upsample_data
from sklearn.model_selection import StratifiedKFold
from transformers import ASTConfig
import pandas as pd
from transformers import ASTFeatureExtractor
from transformers import ASTPreTrainedModel, ASTModel, AutoConfig, ASTConfig
import torch.nn as nn
from tqdm import tqdm

device = torch.device('cpu' if torch.cuda.is_available() else 'cpu') 

class DenseLayer(nn.Module):
    def __init__(self, config, output_dim):
        super().__init__()
        self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dense = nn.Linear(config.hidden_size, output_dim)
    
    def forward(self, hidden_state):
        hidden_state = self.layernorm(hidden_state)
        hidden_state = self.dense(hidden_state)
        return hidden_state
    
class ASTagModel(ASTPreTrainedModel):
    def __init__(self, config, *inputs, **kwargs):
        super().__init__(config, *inputs, **kwargs)
        self.audio_spectrogram_transformer = ASTModel(config)

        for pname, p in self.named_parameters():
            if pname.find('layer.') >= 0:
                layer = int(pname.split('.')[3])
                if layer<=CFG.ast_fix_layer:
                    p.requires_grad = False
            else:
                p.requires_grad = False
            
        self.linear = DenseLayer(config, CFG.num_classes)
        self.n_class = CFG.num_classes
    
    def forward(self, input_values):
        outputs = self.audio_spectrogram_transformer(input_values)
        hidden_states = outputs.last_hidden_state
        pool_output = torch.mean(hidden_states, dim=1)
        # pool_output = outputs.pooler_output
        logits = self.linear(pool_output)
        return nn.Sigmoid()(logits) if CFG.loss=='BCE' else logits

In [4]:


GCS_PATH = CFG.base_path

df = pd.read_csv(f'{CFG.base_path}/train_metadata.csv')
df['filepath'] = GCS_PATH + '/train_audio/' + df.filename
df['target'] = df.primary_label.map(CFG.name2label)

f_df = filter_data(df, thr=5)
# f_df.cv.value_counts().plot.bar(legend=True)
up_df = upsample_data(df, thr=50)

CFG.class_weights = up_df.primary_label.value_counts()[:].to_numpy()

# Initialize the StratifiedKFold object with 5 splits and shuffle the data
skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=CFG.seed)

# Reset the index of the dataframe
df = df.reset_index(drop=True)

# Create a new column in the dataframe to store the fold number for each row
df["fold"] = -1

# Iterate over the folds and assign the corresponding fold number to each row in the dataframe
for fold, (train_idx, val_idx) in enumerate(skf.split(df, df['primary_label'])):
    df.loc[val_idx, 'fold'] = fold



In [None]:
ckpt = torch.load('./ast_int8_state.pth', map_location=device)

config = ASTConfig() 
model = ASTagModel(config)
model.qconfig = torch.quantization.get_default_qconfig('fbgemm')
model_fp32_prepared = torch.quantization.prepare(model)
model_int8 = torch.quantization.convert(model_fp32_prepared)

model_int8.load_state_dict(ckpt)
model = model_int8
model.eval()
# model = model.load_state_dict(ckpt)

# model = torch.load('/home/plathzheng/program/ml_tutorial/ml_project_birdclef23/ast_int8.pth', map_location=device)
# torch.save(model.state_dict(), './ast_int8_state.pth')

In [None]:
def evaluate(model):
    for batch in tqdm(loader_eval):
        audio, label = batch
        audio = audio.to(device)

        prob = model(audio)

evaluate(model)

In [21]:
import torch
from torch.quantization import get_default_qat_qconfig, get_default_qconfig, QConfigDynamic, QConfig
import torch.quantization.quantize_fx as quantize_fx
from torch.quantization.quantize_fx import prepare_fx
import copy

In [None]:
config = ASTConfig() 
float_model = ASTagModel(config)
dataset_eval = ASTDataset(df, fold=4, mode='eval')
loader_eval = DataLoader(dataset_eval, batch_size=10, shuffle=False, num_workers=0 if CFG.debug else 10)
import torch
from torch.quantization import get_default_qconfig
from torch.quantization.quantize_fx import prepare_fx

float_model.eval()

def calibrate(model, data_loader):
    model.eval()
    with torch.no_grad():
        for image, target in data_loader:
            model(image)
            break

qconfig_dict = {"": torch.quantization.default_dynamic_qconfig}
prepared_model = prepare_fx(float_model, qconfig_dict)
# Run calibration
calibrate(prepared_model, loader_eval)

model_int8 = quantize_fx.convert_fx(prepared_model)

In [30]:
state_dict = torch.load('experiments/ast_9layer/trial_1/ast.pth', map_location=device)
config = ASTConfig() 
model = ASTagModel(config=config)
model = model.to(device)
model.load_state_dict(state_dict)

<All keys matched successfully>

In [37]:
import numpy as np
import torch.nn.functional as F
import sklearn

def measurement(y_true, y_pred, padding_factor=5):
    if not CFG.loss=='BCE':
        y_true = F.one_hot(torch.from_numpy(y_true), num_classes=CFG.num_classes).numpy()
    # y_true = y_true.numpy()
    num_classes = y_true.shape[1]
    pad_rows = np.array([[1]*num_classes]*padding_factor)
    y_true = np.concatenate([y_true, pad_rows])
    y_pred = np.concatenate([y_pred, pad_rows])
    score = sklearn.metrics.average_precision_score(y_true, y_pred, average='macro',)
    roc_aucs = sklearn.metrics.roc_auc_score(y_true, y_pred, average='macro')
    return score, roc_aucs

def evaluate(model, loader):
    label_stack = []
    pred_stack = torch.randn(size=(1, CFG.num_classes)).to(device)
    for batch in tqdm(loader):
        audio, label = batch
        audio = audio.to(device)
        label = label.to(device)
        with torch.no_grad():
            prob = model(audio)
        label_stack += label.cpu().numpy().tolist()
        pred_stack = torch.cat([pred_stack, prob], dim=0)
    pred_stack = pred_stack[1:]
    pred_stack = pred_stack.detach().cpu().numpy()
    label_stack = np.array(label_stack)
    acc, auc = measurement(label_stack, pred_stack)
    return acc, auc

In [31]:
quantized_model = torch.quantization.quantize_dynamic(
    model, {torch.nn.Linear}, dtype=torch.qint8
)
print(quantized_model)

ASTagModel(
  (audio_spectrogram_transformer): ASTModel(
    (embeddings): ASTEmbeddings(
      (patch_embeddings): ASTPatchEmbeddings(
        (projection): Conv2d(1, 768, kernel_size=(16, 16), stride=(10, 10))
      )
      (dropout): Dropout(p=0.0, inplace=False)
    )
    (encoder): ASTEncoder(
      (layer): ModuleList(
        (0): ASTLayer(
          (attention): ASTAttention(
            (attention): ASTSelfAttention(
              (query): DynamicQuantizedLinear(in_features=768, out_features=768, dtype=torch.qint8, qscheme=torch.per_tensor_affine)
              (key): DynamicQuantizedLinear(in_features=768, out_features=768, dtype=torch.qint8, qscheme=torch.per_tensor_affine)
              (value): DynamicQuantizedLinear(in_features=768, out_features=768, dtype=torch.qint8, qscheme=torch.per_tensor_affine)
              (dropout): Dropout(p=0.0, inplace=False)
            )
            (output): ASTSelfOutput(
              (dense): DynamicQuantizedLinear(in_features=768, out_

In [32]:
def print_size_of_model(model):
    torch.save(model.state_dict(), "temp.p")
    print('Size (MB):', os.path.getsize("temp.p")/1e6)
    os.remove('temp.p')

print_size_of_model(model)
print_size_of_model(quantized_model)

Size (MB): 345.656363
Size (MB): 90.305543


In [None]:
import time
dataset_eval = ASTDataset(df, fold=4, mode='eval')
loader_eval = DataLoader(dataset_eval, batch_size=10, shuffle=False, num_workers=0 if CFG.debug else 10)
def time_model_evaluation(model, loader):
    eval_start_time = time.time()
    result = evaluate(model, loader)
    eval_end_time = time.time()
    eval_duration_time = eval_end_time - eval_start_time
    print(result)
    print("Evaluate total time (seconds): {0:.1f}".format(eval_duration_time))

# Evaluate the original FP32 BERT model
time_model_evaluation(model, loader_eval)

In [39]:
device = torch.device('cpu')
quantized_model = quantized_model.to(device)
time_model_evaluation(quantized_model, loader_eval)

100%|██████████| 338/338 [35:51<00:00,  6.37s/it]


(0.8662550425948506, 0.9842595831141752)
Evaluate total time (seconds): 2152.1


In [None]:
import torch.nn as nn
import torch.quantization as quant

qconfig = quant.QConfig(
    activation=quant.MinMaxObserver.with_args(dtype=torch.qint8),
    weight=quant.default_per_channel_weight_observer,
    weight_prepare=quant.default_weight_preparation,
    bias=quant.default_bias_observer,
    dtype=torch.qint8
)
qconfig_dict = {"": qconfig}
qconfig_dict[nn.Linear] = quant.QConfig(weight=quant.default_weight_observer)



In [40]:
import argparse
parser = argparse.ArgumentParser(description='PyTorch Kaggle Bird Implementation')
parser.add_argument('--batch_size', type=int, default=10, metavar='N',
                    help='input batch size for training (default: 64)')
parser.add_argument('--lr', type=float, default=0.00001, metavar='LR',
                    help='learning rate (default: 0.0002)')
parser.add_argument('--max_epoch', type=int, default=25, metavar='N',
                    help='how many epochs')
parser.add_argument('--experiment_name', type=str, default='efficient_visual',
                    help='experiment name')
parser.add_argument('--n_trials', type=int, default=20,
                    help='number of trials')
parser.add_argument('--best_acc', type=float, default=0,
                    help='number of trials')
parser.add_argument('--model_name', type=str, default='efficient', choices=['beats', 'ast', 'musicnn', 'efficient'])
parser.add_argument('--eval_step', type=int, default=1)
parser.add_argument('--debug', action='store_true')
parser.add_argument('--model_path', type=str, default='experiments/ast/trial_0/ast.pth')
parser.add_argument('--fold', type=int, default=2)
args = parser.parse_args(['--best_acc', '0'])

In [42]:
import torch.nn as nn
import torch.quantization as quant
from optuna_utils.models import ASTagModel

config = ASTConfig() 
model = ASTagModel(config=config, train_config=args)
state_dict = torch.load('experiments/ast_9layer/trial_1/ast.pth', map_location=device)
model = model.to(device)
model.load_state_dict(state_dict)

dataset_train = ASTDataset(df, fold=4, mode='train')
loader_train = DataLoader(dataset_train, batch_size=CFG.batch_size, shuffle=True, num_workers=0 if CFG.debug else 10)
dataset_eval = ASTDataset(df, fold=4, mode='eval')
loader_eval = DataLoader(dataset_eval, batch_size=CFG.batch_size, shuffle=False, num_workers=0 if CFG.debug else 10)

In [43]:
from transformers.optimization import get_cosine_schedule_with_warmup
total_samples = dataset_train.__len__()
num_warmup_steps = (total_samples // args.batch_size) * 2
num_total_steps = (total_samples // args.batch_size) * args.max_epoch
lr_scheduler = get_cosine_schedule_with_warmup(model.optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_total_steps)

In [45]:

qconfig = torch.quantization.get_default_qat_qconfig('fbgemm')

qconfig_dict = {"": qconfig}
qconfig_dict[nn.Linear] = quant.QConfig(weight=quant.default_weight_observer)

prepared_model = quant.prepare_qat(
    model.train(),
    qconfig_dict=qconfig_dict,
    modules_to_fuse=[nn.Linear]
)

TypeError: __new__() missing 1 required positional argument: 'activation'