In [1]:
from datetime import datetime
from typing import Any, Dict, List, Optional, Union

import torch
import torchmetrics
from lightning import LightningModule
from s3prl.nn import S3PRLUpstream


class S3PRLUpstreamMLPDownstreamForCls(LightningModule):
    """Classifier: S3PRL upstream features -> layer-weighted average -> MLP -> logits.

    The hidden states of the selected upstream layers are combined with a learned,
    softmax-normalised weight per layer (``avg_weights``). The result goes through
    ``hidden_layers`` Linear+ReLU blocks (``downstream``) and then a final linear layer
    (``out_layer``) that produces one logit per class.

    Args:
        state: Experiment state. Must contain ``'speaker_id_mapping'``; its length is
            used as the number of classes.
        upstream: Name of the S3PRL upstream passed to ``S3PRLUpstream``.
        upstream_layers_output_to_use: ``'all'``, a single layer index, or a list of
            layer indices whose hidden states are averaged.
        hidden_layers: Number of Linear+ReLU blocks in the downstream MLP.
        hidden_dim: Width of every downstream hidden layer.
        optimizer: Optimizer factory, called as ``optimizer(params=...)``. Defaults to
            ``torch.optim.Adam``.
        lr_scheduler: Optional scheduler factory, called as ``lr_scheduler(optimizer)``.
            It is stepped once per epoch and monitors ``val_loss``.

    Raises:
        ValueError: If ``upstream_layers_output_to_use`` is a string other than ``'all'``.
    """

    def __init__(
        self,
        state: Dict[str, Any],
        upstream: str = 'wavlm_base_plus',
        upstream_layers_output_to_use: Union[str, List[int], int] = 'all',
        hidden_layers: int = 2,
        hidden_dim: int = 128,
        optimizer: Optional[Any] = None,
        lr_scheduler: Optional[Any] = None,
    ):
        super().__init__()
        self.opt_state = state
        self.optimizer = optimizer if optimizer is not None else torch.optim.Adam
        self.lr_scheduler = lr_scheduler
        self.mapping = state['speaker_id_mapping']
        self.num_classes = len(self.mapping)

        self.upstream = S3PRLUpstream(upstream)
        # NOTE(review): assumes every upstream layer has the same width as layer 0.
        upstream_dim = self.upstream.hidden_sizes[0]

        layer_dims = [upstream_dim] + [hidden_dim] * hidden_layers

        self.downstream = torch.nn.Sequential(
            *[
                torch.nn.Sequential(torch.nn.Linear(dim_in, dim_out), torch.nn.ReLU())
                for dim_in, dim_out in zip(layer_dims[:-1], layer_dims[1:])
            ]
        )
        # Kept as a separate module (not appended to `downstream`) so the state_dict
        # keys `out_layer.*` stay compatible with already-saved checkpoints.
        self.out_layer = torch.nn.Linear(layer_dims[-1], self.num_classes)

        if isinstance(upstream_layers_output_to_use, int):
            upstream_layers_output_to_use = [upstream_layers_output_to_use]
        elif isinstance(upstream_layers_output_to_use, str):
            if upstream_layers_output_to_use != 'all':
                raise ValueError(
                    "upstream_layers_output_to_use must be 'all', an int or a list of "
                    f"ints, got {upstream_layers_output_to_use!r}"
                )
            upstream_layers_output_to_use = list(range(len(self.upstream.hidden_sizes)))

        self.upstream_layers_output_to_use = upstream_layers_output_to_use

        # One learnable weight per selected layer; softmax-normalised in forward().
        self.avg_weights = torch.nn.Parameter(
            torch.ones(
                len(upstream_layers_output_to_use),
            )
        )

        self.accuracy_top1 = torchmetrics.classification.Accuracy(
            task="multiclass", num_classes=self.num_classes
        )
        self.accuracy_top5 = torchmetrics.classification.Accuracy(
            task="multiclass", num_classes=self.num_classes, top_k=5
        )

    def forward(self, x: Dict[str, torch.Tensor]) -> torch.Tensor:
        """Return the class logits for a batch dict (see ``forward_upstream`` for keys)."""
        hidden = self.forward_upstream(x)

        w = torch.nn.functional.softmax(self.avg_weights, dim=0)

        # w[None, :, None] has shape (1, n_layers, 1); this weights dim 1 of `hidden`
        # as intended when hidden is (batch, n_layers, dim) -- TODO confirm other shapes.
        avg_hidden = torch.sum(
            hidden[:, self.upstream_layers_output_to_use] * w[None, :, None],
            dim=1,
        )

        return self.out_layer(self.downstream(avg_hidden))

    def forward_upstream(self, x: Dict[str, torch.Tensor]) -> torch.Tensor:
        """Return the stacked upstream hidden states for a batch.

        If ``x['upstream_embedding_precalculated']`` is present and all-true,
        ``x['upstream_embedding']`` is returned as is. Otherwise the upstream is run on
        ``x['wav']`` / ``x['wav_lens']`` without gradients, and its per-layer outputs
        are stacked and transposed so the layer axis is dim 1.
        """
        precalculated = x.get("upstream_embedding_precalculated")
        # A missing flag means "not precalculated" instead of crashing on None.all().
        if precalculated is not None and bool(torch.as_tensor(precalculated).all()):
            return x['upstream_embedding']
        with torch.no_grad():
            hidden, _ = self.upstream(x['wav'], wavs_len=x['wav_lens'])
        # NOTE(review): S3PRL upstreams presumably return per-frame states, making this
        # (batch, layers, time, dim), while forward() weights layers with
        # w[None, :, None], which fits time-pooled (batch, layers, dim) inputs such as
        # the precalculated embeddings. Confirm whether a pooling over time is missing.
        return torch.stack(hidden).transpose(0, 1)

    def _logits(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor:
        """Run the model and return logits with a leading batch dimension.

        ``squeeze`` removes singleton dims; a batch of one therefore collapses to 1-D
        and is re-expanded to (1, num_classes).
        """
        yhat = self(batch).squeeze()
        if yhat.dim() == 1:
            yhat = yhat.unsqueeze(dim=0)
        return yhat

    def training_step(
        self,
        batch: Dict[str, torch.Tensor],
        batch_idx: int,  # pylint: disable=unused-argument
    ):
        """Compute, log (as ``train_loss``) and return the cross-entropy loss."""
        losses = self.calculate_loss(batch)
        self.log_results(losses, 'train')
        return losses

    def validation_step(
        self,
        batch: Dict[str, torch.Tensor],
        batch_idx: int,  # pylint: disable=unused-argument
    ):
        """Compute and log the cross-entropy loss as ``val_loss``."""
        losses = self.calculate_loss(batch)
        self.log_results(losses, 'val')

    def test_step(
        self,
        batch: Dict[str, torch.Tensor],
        batch_idx: int,  # pylint: disable=unused-argument
    ) -> None:
        """Log ``test_loss``, ``test_accuracy_top1`` and ``test_accuracy_top5``.

        The forward pass runs once and its logits feed both the loss and the
        accuracy metrics, using the same batch-of-one handling as the loss.
        """
        yhat = self._logits(batch)
        y = batch['class_id']

        losses = torch.nn.functional.cross_entropy(yhat, y)
        accuracy_top1 = self.accuracy_top1(yhat, y)
        accuracy_top5 = self.accuracy_top5(yhat, y)

        self.log_results(losses, 'test')
        self.log_results(accuracy_top1, 'test', 'accuracy_top1')
        self.log_results(accuracy_top5, 'test', 'accuracy_top5')

    def calculate_loss(self, x: Dict[str, torch.Tensor]) -> torch.Tensor:
        """Return the cross-entropy between the model logits and ``x['class_id']``."""
        return torch.nn.functional.cross_entropy(self._logits(x), x['class_id'])

    def log_results(self, losses, prefix, metric="loss") -> None:
        """Log ``{prefix}_{metric}`` together with ``{prefix}_time``.

        The time value is the current local time formatted as YYMMDDHHMMSS and
        converted to an int.
        """
        log_loss = {
            "time": int(datetime.now().strftime('%y%m%d%H%M%S')),
            metric: losses,
        }
        self.log_dict({'{}_{}'.format(prefix, k): v for k, v in log_loss.items()})

    def configure_optimizers(
        self,
    ) -> Dict[str, Any]:
        """Build the optimizer and, if configured, an epoch-level LR scheduler on ``val_loss``."""
        optimizer = self.optimizer(params=self.parameters())
        optimizer_config = {"optimizer": optimizer}
        if self.lr_scheduler is not None:
            lr_scheduler_config = {
                "scheduler": self.lr_scheduler(optimizer),
                "monitor": "val_loss",
                "interval": "epoch",
                "frequency": 1,
            }
            optimizer_config['lr_scheduler'] = lr_scheduler_config
        return optimizer_config

    def set_optimizer_state(self, state: Dict[str, Any]) -> None:
        """Replace the stored experiment state (``self.opt_state``)."""
        self.opt_state = state


  from .autonotebook import tqdm as notebook_tqdm
  torchaudio.set_audio_backend("sox_io")
ESPnet is not installed, cannot use espnet_hubert upstream


## Run inference from a checkpoint

In [5]:
import joblib

In [6]:
# joblib.load unpickles the file: only load state files from trusted sources.
state = joblib.load(
    "../speech_hypertuning/experiments/batch_size_vs_learning_rate/"
    "batch_size_1-lr_0.000001/state.pkl"
)

In [4]:
# Rebuild the architecture of this experiment, then restore its trained weights.
model = S3PRLUpstreamMLPDownstreamForCls(
    state=state,
    optimizer=torch.optim.Adam,
    hidden_dim=4096,
    hidden_layers=1,
)
checkpoint = torch.load(
    "../speech_hypertuning/experiments/batch_size_vs_learning_rate/"
    "batch_size_1-lr_0.000001/checkpoints/epoch=52-step=100700.ckpt"
)
model.load_state_dict(checkpoint['state_dict'])



<All keys matched successfully>

In [5]:
# NOTE(review): absolute, user-specific path; consider a configurable data directory.
upstream_embedding = torch.load(
    "/home/eernst/Voxceleb1/avg_embeddings/"
    "id10020_1elTcNGC3q8_00022.pt"
)

In [6]:
upstream_embedding.size()

torch.Size([13, 768])

In [21]:
# Inference only: skip autograd bookkeeping. The flag is a boolean tensor telling
# forward_upstream to use the cached embedding instead of running the upstream.
with torch.no_grad():
    y = model({
        "upstream_embedding": upstream_embedding.unsqueeze(dim=0),  # add a batch dim
        "upstream_embedding_precalculated": torch.tensor([True]),
    })
y.shape

torch.Size([1, 100])

## Remove the upstream weights from the checkpoint

In [9]:
from copy import deepcopy

In [10]:
# Deep copy: deleting keys from checkpoint_clean['state_dict'] below must not also remove
# them from `checkpoint`, whose upstream weights are reused later in this notebook.
checkpoint_clean = deepcopy(checkpoint)

In [11]:
# Top-level entries of the Lightning checkpoint.
checkpoint_clean.keys()

dict_keys(['epoch', 'global_step', 'pytorch-lightning_version', 'state_dict', 'loops', 'callbacks', 'optimizer_states', 'lr_schedulers', 'MixedPrecision'])

In [12]:
# Parameter names before cleaning: avg_weights, the S3PRL upstream (upstream.*), and the downstream layers.
checkpoint_clean['state_dict'].keys()

odict_keys(['avg_weights', 'upstream.upstream.model.mask_emb', 'upstream.upstream.model.feature_extractor.conv_layers.0.0.weight', 'upstream.upstream.model.feature_extractor.conv_layers.0.2.weight', 'upstream.upstream.model.feature_extractor.conv_layers.0.2.bias', 'upstream.upstream.model.feature_extractor.conv_layers.1.0.weight', 'upstream.upstream.model.feature_extractor.conv_layers.2.0.weight', 'upstream.upstream.model.feature_extractor.conv_layers.3.0.weight', 'upstream.upstream.model.feature_extractor.conv_layers.4.0.weight', 'upstream.upstream.model.feature_extractor.conv_layers.5.0.weight', 'upstream.upstream.model.feature_extractor.conv_layers.6.0.weight', 'upstream.upstream.model.post_extract_proj.weight', 'upstream.upstream.model.post_extract_proj.bias', 'upstream.upstream.model.encoder.pos_conv.0.bias', 'upstream.upstream.model.encoder.pos_conv.0.weight_g', 'upstream.upstream.model.encoder.pos_conv.0.weight_v', 'upstream.upstream.model.encoder.layers.0.self_attn.grep_a', 'up

In [13]:
# Collect the upstream parameter names first: keys cannot be deleted while iterating the dict.
model_state_keys = list(checkpoint_clean['state_dict'].keys())
upstream_keys = [name for name in model_state_keys if name.startswith("upstream")]
for key in upstream_keys:
    del checkpoint_clean['state_dict'][key]

In [14]:
# After cleaning, only avg_weights and the downstream/out_layer parameters should remain.
checkpoint_clean['state_dict'].keys()

odict_keys(['avg_weights', 'downstream.0.0.weight', 'downstream.0.0.bias', 'out_layer.weight', 'out_layer.bias'])

In [15]:
# NOTE(review): displays the whole checkpoint including tensor previews; inspecting keys/shapes is usually enough.
checkpoint_clean

{'epoch': 52,
 'global_step': 100700,
 'pytorch-lightning_version': '2.2.0.post0',
 'state_dict': OrderedDict([('avg_weights',
               tensor([0.9748, 0.9834, 0.9907, 0.9892, 0.9957, 1.0095, 1.0282, 1.0282, 1.0211,
                       1.0268, 1.0337, 1.0138, 0.9327], device='cuda:0')),
              ('downstream.0.0.weight',
               tensor([[-0.0295,  0.0338, -0.0306,  ...,  0.0119,  0.0276,  0.0250],
                       [ 0.0399,  0.0216,  0.0165,  ..., -0.0096,  0.0119,  0.0121],
                       [ 0.0121,  0.0246,  0.0221,  ...,  0.0051,  0.0272,  0.0275],
                       ...,
                       [ 0.0045,  0.0237, -0.0173,  ..., -0.0288, -0.0116,  0.0337],
                       [ 0.0079, -0.0195, -0.0240,  ..., -0.0047,  0.0133,  0.0240],
                       [ 0.0082,  0.0203, -0.0200,  ..., -0.0167, -0.0037,  0.0164]],
                      device='cuda:0')),
              ('downstream.0.0.bias',
               tensor([-0.0201,  0.0015,  0.0

In [16]:
# Persist the upstream-free checkpoint next to this notebook.
torch.save(obj=checkpoint_clean, f="model_clean.pt")

### Test inference with the cleaned checkpoint and a cached upstream embedding

In [17]:
# Fresh model used to check that the cleaned checkpoint saved above is still usable.
new_model = S3PRLUpstreamMLPDownstreamForCls(
    state=state,
    hidden_layers=1,
    hidden_dim=4096,
    optimizer=torch.optim.Adam,
)
# Load the stripped checkpoint written above (not the original one), so this section
# actually exercises it; its missing upstream weights are restored in the next cell.
new_checkpoint = torch.load("model_clean.pt")

In [18]:
# Copy every upstream.* entry of the original `checkpoint` into `new_checkpoint`.
new_checkpoint['state_dict'].update(
    (param_name, tensor)
    for param_name, tensor in checkpoint['state_dict'].items()
    if param_name.startswith("upstream")
)

In [19]:
# Load the state dict prepared in `new_checkpoint` by the previous cells. Loading the
# untouched original `checkpoint` here would not exercise the checkpoint under test.
new_model.load_state_dict(new_checkpoint['state_dict'])

<All keys matched successfully>

In [22]:
# Inference only: skip autograd bookkeeping. The flag is a boolean tensor telling
# forward_upstream to use the cached embedding instead of running the upstream.
with torch.no_grad():
    y = new_model({
        "upstream_embedding": upstream_embedding.unsqueeze(dim=0),  # add a batch dim
        "upstream_embedding_precalculated": torch.tensor([True]),
    })
y.shape

torch.Size([1, 100])

# Free disk space: strip the upstream weights from every checkpoint in the project

In [2]:
from pathlib import Path

In [3]:
# Root directory searched below for experiments (each with a state.pkl and a checkpoints/ folder).
project_path = "../speech_hypertuning/experiments/batch_size_vs_learning_rate/"

In [8]:
# Model used only to verify that each cleaned checkpoint still loads once its upstream
# weights are restored.
model = S3PRLUpstreamMLPDownstreamForCls(
    state=state,
    hidden_layers=1,
    hidden_dim=4096,
    optimizer=torch.optim.Adam,
) #TODO: take from state
# Reference checkpoint supplying the upstream weights during verification. It lives under
# project_path, so the cleaning loop may strip it as well; on a re-run it would then have
# no upstream weights left, and every verification would fail. Fail fast instead.
original_checkpoint = torch.load("../speech_hypertuning/experiments/batch_size_vs_learning_rate/batch_size_1-lr_0.000001/checkpoints/epoch=999-step=1900000.ckpt")
assert any(
    key.startswith("upstream") for key in original_checkpoint['state_dict']
), "Reference checkpoint has no upstream weights (already cleaned?); use an uncleaned checkpoint."



In [12]:
# Strip the upstream weights from the single checkpoint of every finished experiment
# (one with 'test_metrics' in its state). Each stripped checkpoint is first written to a
# temporary file and verified: upstream weights are restored from `original_checkpoint`,
# then the result is loaded into `model` with strict key matching. Only after that does it
# replace the original file, so a failed verification never destroys a checkpoint.
states_paths = Path(project_path).rglob('*state.pkl')
for path in states_paths:
    experiment_name = path.parent.name
    print(f"Cleaning {experiment_name}")
    # joblib.load unpickles the file: only run this on trusted experiment directories.
    state = joblib.load(path)
    checkpoint_paths = list(path.parent.rglob('*checkpoints/*.ckpt'))
    if len(checkpoint_paths) != 1:
        # No checkpoint, or several (ambiguous): leave this experiment untouched.
        print(f"Expected exactly one checkpoint, found {len(checkpoint_paths)}, skipping")
        continue
    if 'test_metrics' not in state:
        print("Not test metrics in state, skipping")
        continue
    checkpoint_path = checkpoint_paths[0]
    checkpoint = torch.load(checkpoint_path)

    # Delete (collect names first: keys cannot be deleted while iterating the dict).
    upstream_keys = [key for key in checkpoint['state_dict'] if key.startswith("upstream")]
    if not upstream_keys:
        print(f"{experiment_name} already cleaned")
        continue
    for key in upstream_keys:
        del checkpoint['state_dict'][key]

    # Write next to the original and verify the saved round trip before replacing anything.
    tmp_path = checkpoint_path.with_name(checkpoint_path.name + ".tmp")
    torch.save(checkpoint, tmp_path)
    try:
        new_checkpoint = torch.load(tmp_path)
        for key, value in original_checkpoint['state_dict'].items():
            if key.startswith("upstream"):
                new_checkpoint['state_dict'][key] = value
        # load_state_dict is strict by default (raises on mismatched keys); the returned
        # missing/unexpected key lists are checked explicitly as well.
        msg = model.load_state_dict(new_checkpoint['state_dict'])
        if msg.missing_keys or msg.unexpected_keys:
            raise ValueError(msg)
    except Exception:
        tmp_path.unlink(missing_ok=True)  # the original checkpoint is still intact
        raise
    # Same-directory rename: the original is replaced in one step.
    tmp_path.replace(checkpoint_path)
    print(f"Succesfully cleaned {experiment_name}")

Cleaning batch_size_1-lr_0.0005
batch_size_1-lr_0.0005 already cleaned
Cleaning batch_size_128-lr_0.000001
Succesfully cleaned batch_size_128-lr_0.000001
Cleaning batch_size_8-lr_0.1
batch_size_8-lr_0.1 already cleaned
Cleaning batch_size_1900-lr_0.000005
batch_size_1900-lr_0.000005 already cleaned
Cleaning batch_size_16-lr_0.05
batch_size_16-lr_0.05 already cleaned
Cleaning batch_size_1900-lr_0.00005
batch_size_1900-lr_0.00005 already cleaned
Cleaning batch_size_4-lr_0.001
batch_size_4-lr_0.001 already cleaned
Cleaning batch_size_256-lr_0.000005
batch_size_256-lr_0.000005 already cleaned
Cleaning batch_size_1024-lr_0.05
batch_size_1024-lr_0.05 already cleaned
Cleaning batch_size_2-lr_0.5
batch_size_2-lr_0.5 already cleaned
Cleaning batch_size_128-lr_0.0001
batch_size_128-lr_0.0001 already cleaned
Cleaning batch_size_512-lr_0.5
batch_size_512-lr_0.5 already cleaned
Cleaning batch_size_512-lr_1
batch_size_512-lr_1 already cleaned
Cleaning batch_size_64-lr_0.5
batch_size_64-lr_0.5 alread