# FEATURE EXTRACTION

This tutorial shows how to use our conformer-based model to extract features from the encoder.

**Note:** To run this tutorial, make sure your working directory is the `tutorials` folder.

In [None]:
import sys
# Put the parent directory at the front of the module search path. The path is
# relative, so it only resolves as intended when the working directory is the
# tutorials folder.
sys.path.insert(0, "../")

In [None]:
import os
import torch

In [None]:
import argparse

# Build an empty argument namespace: an explicit empty argv list is parsed so
# the notebook kernel's own command-line arguments are never consulted.
# Attributes (e.g. the modality) are attached to `args` later on.
parser = argparse.ArgumentParser()
args, _ = parser.parse_known_args([])

## 1. Build our model

In [None]:
from espnet.nets.pytorch_backend.e2e_asr_conformer import E2E
from pytorch_lightning import LightningModule
from datamodule.transforms import TextTransform

In [None]:
class ModelModule(LightningModule):
    """Lightning wrapper around the conformer-based E2E model.

    Only the frontend, the encoder projection and the encoder are used by
    ``forward``, so calling the module yields encoder features.
    """

    def __init__(self, args):
        """Build the E2E model from ``args``.

        Args:
            args: Namespace-like object. ``args.modality`` is required;
                ``args.ctc_weight`` is optional and falls back to 0.1.
        """
        super().__init__()
        self.args = args
        self.save_hyperparameters(args)

        self.modality = args.modality
        self.text_transform = TextTransform()
        self.token_list = self.text_transform.token_list

        # The output vocabulary size is taken from the text transform's token list.
        vocab_size = len(self.token_list)
        ctc_weight = getattr(args, "ctc_weight", 0.1)
        self.model = E2E(vocab_size, self.modality, ctc_weight=ctc_weight)

    def forward(self, x):
        """Return encoder features for a single, unbatched input.

        Args:
            x: Input tensor without a batch dimension (the tutorial uses a
                (length, num_channel, height, width) video tensor).

        Returns:
            The encoder output with the leading batch dimension removed.
        """
        # Add a leading batch dimension of size 1 before the frontend.
        batched = x.unsqueeze(0)
        frontend_out = self.model.frontend(batched)
        projected = self.model.proj_encoder(frontend_out)
        # The encoder returns a pair; only the first element is kept. The
        # second argument is passed as None (presumably "no mask" - confirm).
        encoded, _ = self.model.encoder(projected, None)
        # Drop the batch dimension that was added above.
        return encoded.squeeze(0)

## 2. Download a pre-trained checkpoint

In [None]:
# Fetch the pre-trained checkpoint into the current directory. This line is an
# IPython shell escape, so it needs a notebook/IPython kernel and `wget`.
# NOTE(review): the file is downloaded over plain HTTP and later handed to
# torch.load; consider verifying its integrity before loading it.
!wget http://www.doc.ic.ac.uk/~pm4115/autoAVSR/vsr_trlrs3_base.pth -O ./vsr_trlrs3_base.pth
# Local path of the downloaded checkpoint, used when loading the weights.
model_path = "./vsr_trlrs3_base.pth"

## 3. Load weights from the checkpoint

In [None]:
# Select the video (lip-reading) modality before building the model.
args.modality = "video"
model = ModelModule(args)

# Returning the storage unchanged from map_location keeps every tensor on the
# CPU, so the checkpoint loads on machines without a GPU.
# NOTE(review): torch.load unpickles the file; only load checkpoints you trust.
ckpt = torch.load(model_path, map_location=lambda storage, loc: storage)

# The checkpoint is passed directly to the wrapped E2E model as its state dict.
model.model.load_state_dict(ckpt)

# Freeze the module for feature extraction (no training is done here).
model.freeze()

## 4. Use the pre-trained model to extract features

A placeholder tensor `x` of shape `(length, num_channel, height, width)` stands in for the input to the lip-reading model.

In [None]:
# Random placeholder input: (length, num_channel, height, width) = (10, 1, 88, 88).
x = torch.randn(10, 1, 88, 88)

# Inference mode disables autograd tracking while the features are computed.
with torch.inference_mode():
    y = model(x)

# Show the size of the extracted feature tensor.
print(y.shape)