In [None]:
from s3prl.nn import S3PRLUpstream
import torch
import pytorch_lightning as pl
from typing import Any, Dict, Union, List


In [None]:
class S3PRLUpstreamMLPDownstrem(pl.LightningModule):
    """S3PRL upstream feature extractor followed by a trainable MLP classifier.

    The upstream model is evaluated under ``torch.no_grad()``; the selected
    upstream layer outputs are combined with learnable softmax weights and fed
    to a stack of Linear+ReLU blocks and a final linear output layer.

    NOTE: a later cell in this notebook defines a class with the same name;
    once that cell runs, it shadows this definition.

    Args:
        state: Mapping that must contain ``'class_map'``; its length sets the
            number of output units.
        upstream: Name passed to ``S3PRLUpstream``.
        layer: Upstream layer index, list of indices, or ``'all'`` to use every
            entry of ``self.upstream.hidden_sizes``.
        hidden_layers: Number of Linear+ReLU blocks in the MLP.
        hidden_dim: Width of each hidden block.
    """

    def __init__(self, state, upstream='data2vec', layer=-1, hidden_layers=2, hidden_dim=128):
        super().__init__()
        self.upstream = S3PRLUpstream(upstream)
        self.mapping = state['class_map']
        # Uses the first layer's size as the MLP input size
        # (assumes all selected layers share it — TODO confirm).
        upstream_dim = self.upstream.hidden_sizes[0]
        layer_dims = [upstream_dim] + [hidden_dim] * hidden_layers
        self.net = torch.nn.Sequential(*[
            torch.nn.Sequential(torch.nn.Linear(dim_in, dim_out), torch.nn.ReLU())
            for dim_in, dim_out in zip(layer_dims[:-1], layer_dims[1:])
        ])
        if isinstance(layer, int):
            layer = [layer]
        if layer == 'all':
            layer = list(range(len(self.upstream.hidden_sizes)))
        # One learnable weight per selected layer; softmax-normalised in forward().
        self.avg_weights = torch.nn.Parameter(torch.ones(len(layer),))
        self.layer = layer
        self.out_layer = torch.nn.Linear(layer_dims[-1], len(self.mapping))

    def forward(self, x):
        """Return the classifier output for a batch dict with ``'wav'`` and ``'wavs_len'``.

        NOTE(review): nothing here reduces the time axis, so the output
        presumably keeps one prediction per upstream frame — confirm this is
        what the loss in ``training_step`` expects.
        """
        # No gradients are computed through the upstream model.
        with torch.no_grad():
            hidden, _ = self.upstream(x['wav'], wavs_len=x['wavs_len'])
        # Stack the per-layer outputs and move the layer axis to position 1.
        hidden = torch.stack(hidden).transpose(0, 1)
        w = torch.nn.functional.softmax(self.avg_weights, dim=0)
        # Weighted sum over the selected layers (axis 1).
        avg_hidden = torch.sum(hidden[:, self.layer] * w[None, :, None, None], dim=1)
        return self.out_layer(self.net(avg_hidden))

    def training_step(self, batch, batch_idx):
        """Cross-entropy between the model output and ``batch['class_id']``."""
        yhat = self(batch)
        y = batch['class_id']
        return torch.nn.functional.cross_entropy(yhat, y)

    def validation_step(self, batch, batch_idx=None):
        """Compute and log the validation cross-entropy (same loss as training).

        ``batch_idx`` is optional so the method works whether or not the caller
        supplies it.
        """
        loss = torch.nn.functional.cross_entropy(self(batch), batch['class_id'])
        self.log('val_loss', loss)
        return loss

    def configure_optimizers(self):
        """Adam (default hyper-parameters) over the parameters that receive gradients."""
        # The upstream runs under no_grad, so only the layer weights, the MLP
        # and the output layer are optimised.
        trainable = [self.avg_weights, *self.net.parameters(), *self.out_layer.parameters()]
        return torch.optim.Adam(trainable)

In [None]:
class S3PRLUpstreamMLPDownstrem(pl.LightningModule):
    """S3PRL upstream feature extractor followed by a trainable MLP classifier.

    The upstream model is evaluated under ``torch.no_grad()``. The selected
    upstream layer outputs are combined with learnable softmax weights, passed
    through ``hidden_layers`` Linear+ReLU blocks and a final linear layer with
    one output per entry of ``state['class_map']``.

    Args:
        state: Mapping that must contain ``'class_map'``.
        upstream: Name passed to ``S3PRLUpstream``.
        upstream_layers_output_to_use: A single layer index, a list of layer
            indices, or ``'all'`` to use every entry of
            ``self.upstream.hidden_sizes``.
        hidden_layers: Number of Linear+ReLU blocks in the downstream MLP.
        hidden_dim: Width of each hidden block.

    Raises:
        ValueError: If ``upstream_layers_output_to_use`` is a string other than
            ``'all'`` or selects no layers.
    """

    def __init__(
        self,
        state: Dict[str, Any],
        upstream: str = 'data2vec',
        upstream_layers_output_to_use: Union[str, List[int], int] = -1,
        hidden_layers: int = 2,
        hidden_dim: int = 128,
    ):
        super().__init__()
        self.mapping = state['class_map']

        self.upstream = S3PRLUpstream(upstream)
        # Uses the first layer's size as the MLP input size
        # (assumes all selected layers share it — TODO confirm).
        upstream_dim = self.upstream.hidden_sizes[0]

        layer_dims = [upstream_dim] + [hidden_dim] * hidden_layers

        self.downstream = torch.nn.Sequential(*[
            torch.nn.Sequential(torch.nn.Linear(dim_in, dim_out), torch.nn.ReLU())
            for dim_in, dim_out in zip(layer_dims[:-1], layer_dims[1:])
        ])
        self.out_layer = torch.nn.Linear(layer_dims[-1], len(self.mapping))  # FIXME: add this at the end of the downstream?

        # Normalise the layer selection to a non-empty list of indices.
        if isinstance(upstream_layers_output_to_use, int):
            upstream_layers_output_to_use = [upstream_layers_output_to_use]
        elif isinstance(upstream_layers_output_to_use, str):
            if upstream_layers_output_to_use != 'all':
                raise ValueError(
                    "upstream_layers_output_to_use must be an int, a list of ints or 'all', "
                    f"got {upstream_layers_output_to_use!r}"
                )
            upstream_layers_output_to_use = list(range(len(self.upstream.hidden_sizes)))
        else:
            upstream_layers_output_to_use = list(upstream_layers_output_to_use)
        if not upstream_layers_output_to_use:
            raise ValueError("upstream_layers_output_to_use must select at least one layer")
        self.upstream_layers_output_to_use = upstream_layers_output_to_use

        # One learnable weight per selected layer; softmax-normalised in forward().
        self.avg_weights = torch.nn.Parameter(torch.ones(len(upstream_layers_output_to_use),))

    def forward(self, x):
        """Return the classifier output for a batch dict with ``'wav'`` and ``'wavs_len'``.

        NOTE(review): nothing here reduces the time axis, so the output
        presumably keeps one prediction per upstream frame — confirm this is
        what the loss in ``training_step`` expects.
        """
        # No gradients are computed through the upstream model.
        with torch.no_grad():
            hidden, _ = self.upstream(x['wav'], wavs_len=x['wavs_len'])
        # Stack the per-layer outputs and move the layer axis to position 1.
        hidden = torch.stack(hidden).transpose(0, 1)

        w = torch.nn.functional.softmax(self.avg_weights, dim=0)

        # Weighted sum over the selected layers (axis 1).
        avg_hidden = torch.sum(hidden[:, self.upstream_layers_output_to_use] * w[None, :, None, None], dim=1)

        return self.out_layer(self.downstream(avg_hidden))

    def training_step(self, batch, batch_idx):
        """Cross-entropy between the model output and ``batch['class_id']``."""
        yhat = self(batch)
        y = batch['class_id']
        return torch.nn.functional.cross_entropy(yhat, y)

    def validation_step(self, batch, batch_idx=None):
        """Compute and log the validation cross-entropy (same loss as training)."""
        loss = torch.nn.functional.cross_entropy(self(batch), batch['class_id'])
        self.log('val_loss', loss)
        return loss

    def configure_optimizers(self):
        """Adam (default hyper-parameters) over the parameters that receive gradients."""
        # The upstream runs under no_grad, so only the layer weights, the MLP
        # and the output layer are optimised.
        trainable = [self.avg_weights, *self.downstream.parameters(), *self.out_layer.parameters()]
        return torch.optim.Adam(trainable)


In [None]:
import joblib

# Location of the saved experiment state, relative to the notebook's working directory.
STATE_PATH = '../speech_hypertuning/experiments/experiment_lr/test_load_dataset/state.pkl'

# `state` must provide a 'class_map' entry for the model defined above.
# NOTE(security): joblib.load unpickles the file and can execute arbitrary code;
# only load state files from a trusted source.
state = joblib.load(STATE_PATH)

In [None]:
# Build the classifier with the default upstream, combining every upstream layer output.
model = S3PRLUpstreamMLPDownstrem(
    state,
    upstream_layers_output_to_use='all',
)

In [None]:
# Smoke test: forward a random batch of two 32000-sample waveforms
# (declared lengths 32000 and 16000) and show the output shape.
model(
    {
        'wav': torch.randn((2, 32000)),
        'wavs_len': torch.tensor([32000, 16000]),
    }
).shape

In [None]:
# Call the upstream model directly on a random batch to inspect its raw outputs.
# `upstream` is never defined as a notebook-level name, so use the instance held by `model`.
hidden, last = model.upstream(torch.randn((2,32000)), wavs_len=torch.tensor([32000,16000]))

In [None]:
# Shape of the first entry of `hidden` (one layer's output while `hidden` is
# still the per-layer sequence returned by the upstream call).
hidden[0].shape

In [None]:
# Stack the per-layer outputs into one tensor and swap the first two axes,
# so the stacked-layer axis ends up at position 1.
# NOTE: this rebinds `hidden`; re-running this cell without re-running the
# upstream call applies the stack/transpose a second time.
hidden = torch.stack(hidden).transpose(0,1)

In [None]:
# Select entries 2, 3 and 5 along axis 1 (the stacked-layer axis once the
# stacking cell has run) and check the resulting shape.
hidden[:,[2,3,5]].shape

In [None]:
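# Shape notes: stacked `hidden` is (BS, L, T, D) — presumably batch, layer, time, feature dim.
# The layer weights broadcast against it as (x, L, x, x), i.e. w[None, :, None, None].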