# LightHuBERT

This LightHuBERT tutorial demonstrates the following features:

1. Load different checkpoints into the LightHuBERT architecture, such as the base supernet, the small supernet, and the stage-1 checkpoint.
2. Sample a subnet and activate it on the model.
3. Run inference with the subnet on a random input.

In [1]:
import sys
sys.path.append("../")

import torch
from lighthubert import LightHuBERT, LightHuBERTConfig

# The original run used device "cuda:2". Pick the device portably instead, so the
# notebook also works on machines with fewer GPUs (or none). Set this explicitly,
# e.g. device = "cuda:2", to target a specific GPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Random dummy waveform: batch of 1, 10,000 samples (~0.625 s at 16 kHz).
wav_input_16khz = torch.randn(1, 10000).to(device)

## LightHuBERT Base Supernet

In [2]:
# checkpoint = torch.load('/path/to/lighthubert_base.pt')
# map_location="cpu" makes loading independent of the device the checkpoint was
# saved from; the weights reach `device` when the model is moved below.
# NOTE: torch>=2.6 defaults to weights_only=True. If loading fails on non-tensor
# entries (e.g. checkpoint['cfg']), pass weights_only=False for trusted files.
checkpoint = torch.load(
    '/workspace/projects/lighthubert/checkpoints/lighthubert_base.pt',
    map_location="cpu",
)
cfg = LightHuBERTConfig(checkpoint['cfg']['model'])
cfg.supernet_type = 'base'
model = LightHuBERT(cfg)
model = model.to(device)
model = model.eval()
# strict=False: checkpoint keys the model does not define (reported as
# unexpected_keys in the output) are skipped instead of raising.
print(model.load_state_dict(checkpoint['model'], strict=False))

# (optional) sample a random subnet from the supernet and activate it
subnet = model.supernet.sample_subnet()
model.set_sample_config(subnet)
params = model.calc_sampled_param_num()
print(f"subnet (Params {params / 1e6:.0f}M) | {subnet}")

# Inference only: disable autograd so no computation graph is built.
with torch.no_grad():
    # extract the representation of the last layer
    rep = model.extract_features(wav_input_16khz)[0]

    # extract the representations of every layer; hs[-1] is expected to match rep
    hs = model.extract_features(wav_input_16khz, ret_hs=True)[0]

print(f"Representation at bottom hidden states: {torch.allclose(rep, hs[-1])}")

2022-03-29 20:52:53 | INFO | lighthubert.lighthubert | predicting heads: [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1]
2022-03-29 20:52:53 | INFO | lighthubert.lighthubert | search space (6,530,347,008 subnets): {'atten_dim': [512, 640, 768], 'embed_dim': [512, 640, 768], 'ffn_ratio': [3.5, 4.0], 'heads_num': [8, 10, 12], 'layer_num': [12]}
2022-03-29 20:52:53 | INFO | lighthubert.lighthubert | min subnet (41 Params): {'atten_dim': [512, 512, 512, 512, 512, 512, 512, 512, 512, 512, 512, 512], 'embed_dim': 512, 'ffn_embed': [1792, 1792, 1792, 1792, 1792, 1792, 1792, 1792, 1792, 1792, 1792, 1792], 'heads_num': [8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8], 'layer_num': 12, 'slide_wsz': ['global', 'global', 'global', 'global', 'global', 'global', 'global', 'global', 'global', 'global', 'global', 'global']}
2022-03-29 20:52:53 | INFO | lighthubert.lighthubert | max subnet (94 Params): {'atten_dim': [768, 768, 768, 768, 768, 768, 768, 768, 768, 768, 768, 768], 'embed_dim': 768, 'ffn_embed': [3072, 3072, 3072

_IncompatibleKeys(missing_keys=[], unexpected_keys=['label_embs_concat', 'final_proj.weight', 'final_proj.bias'])
subnet (Params 85M) | {'atten_dim': [512, 640, 512, 640, 768, 512, 640, 512, 512, 640, 768, 640], 'embed_dim': 768, 'ffn_embed': [3072, 2688, 3072, 2688, 2688, 3072, 2688, 3072, 3072, 2688, 3072, 2688], 'heads_num': [8, 10, 8, 10, 12, 8, 10, 8, 8, 10, 12, 10], 'layer_num': 12, 'slide_wsz': ['global', 'global', 'global', 'global', 'global', 'global', 'global', 'global', 'global', 'global', 'global', 'global']}
Representation at bottom hidden states: True


## LightHuBERT Small Supernet

In [3]:
# checkpoint = torch.load('/path/to/lighthubert_small.pt')
# map_location="cpu" makes loading independent of the device the checkpoint was
# saved from; the weights reach `device` when the model is moved below.
# NOTE: torch>=2.6 defaults to weights_only=True. If loading fails on non-tensor
# entries (e.g. checkpoint['cfg']), pass weights_only=False for trusted files.
checkpoint = torch.load(
    '/workspace/projects/lighthubert/checkpoints/lighthubert_small.pt',
    map_location="cpu",
)
cfg = LightHuBERTConfig(checkpoint['cfg']['model'])
cfg.supernet_type = 'small'
model = LightHuBERT(cfg)
model = model.to(device)
model = model.eval()
# strict=False: checkpoint keys the model does not define (reported as
# unexpected_keys in the output) are skipped instead of raising.
print(model.load_state_dict(checkpoint['model'], strict=False))

# (optional) sample a random subnet from the supernet and activate it
subnet = model.supernet.sample_subnet()
model.set_sample_config(subnet)
params = model.calc_sampled_param_num()
print(f"subnet (Params {params / 1e6:.0f}M) | {subnet}")

# Inference only: disable autograd so no computation graph is built.
with torch.no_grad():
    # extract the representation of the last layer
    rep = model.extract_features(wav_input_16khz)[0]

    # extract the representations of every layer; hs[-1] is expected to match rep
    hs = model.extract_features(wav_input_16khz, ret_hs=True)[0]

print(f"Representation at bottom hidden states: {torch.allclose(rep, hs[-1])}")

2022-03-29 20:53:04 | INFO | lighthubert.lighthubert | LightHuBERT Config: <lighthubert.lighthubert.LightHuBERTConfig object at 0x7f078337e430>
2022-03-29 20:53:06 | INFO | lighthubert.lighthubert | predicting heads: [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1]
2022-03-29 20:53:06 | INFO | lighthubert.lighthubert | search space (951,892,141,473 subnets): {'atten_dim': [256, 384, 512], 'embed_dim': [256, 384, 512], 'ffn_ratio': [3.0, 3.5, 4.0], 'heads_num': [4, 6, 8], 'layer_num': [10, 11, 12]}
2022-03-29 20:53:06 | INFO | lighthubert.lighthubert | min subnet (11 Params): {'atten_dim': [256, 256, 256, 256, 256, 256, 256, 256, 256, 256], 'embed_dim': 256, 'ffn_embed': [768, 768, 768, 768, 768, 768, 768, 768, 768, 768], 'heads_num': [4, 4, 4, 4, 4, 4, 4, 4, 4, 4], 'layer_num': 10, 'slide_wsz': ['global', 'global', 'global', 'global', 'global', 'global', 'global', 'global', 'global', 'global']}
2022-03-29 20:53:06 | INFO | lighthubert.lighthubert | max subnet (44 Params): {'atten_dim': [512, 512, 

_IncompatibleKeys(missing_keys=[], unexpected_keys=['label_embs_concat', 'final_proj.weight', 'final_proj.bias'])
subnet (Params 38M) | {'atten_dim': [256, 512, 256, 256, 384, 384, 256, 512, 256, 512, 384, 512], 'embed_dim': 512, 'ffn_embed': [1792, 1536, 1536, 2048, 1792, 1536, 1536, 2048, 1792, 2048, 2048, 2048], 'heads_num': [4, 8, 4, 4, 6, 6, 4, 8, 4, 8, 6, 8], 'layer_num': 12, 'slide_wsz': ['global', 'global', 'global', 'global', 'global', 'global', 'global', 'global', 'global', 'global', 'global', 'global']}
Representation at bottom hidden states: True


## LightHuBERT Stage 1

In [4]:
# checkpoint = torch.load('/path/to/lighthubert_stage1.pt')
# map_location="cpu" makes loading independent of the device the checkpoint was
# saved from; the weights reach `device` when the model is moved below.
# NOTE: torch>=2.6 defaults to weights_only=True. If loading fails on non-tensor
# entries (e.g. checkpoint['cfg']), pass weights_only=False for trusted files.
checkpoint = torch.load(
    '/workspace/projects/lighthubert/checkpoints/lighthubert_stage1.pt',
    map_location="cpu",
)
cfg = LightHuBERTConfig(checkpoint['cfg']['model'])
cfg.supernet_type = 'base'
model = LightHuBERT(cfg)
model = model.to(device)
model = model.eval()
# strict=False: checkpoint keys the model does not define (reported as
# unexpected_keys in the output) are skipped instead of raising.
print(model.load_state_dict(checkpoint['model'], strict=False))

# (optional) activate the largest subnet of the supernet
subnet = model.supernet.max_subnet
model.set_sample_config(subnet)
params = model.calc_sampled_param_num()
print(f"subnet (Params {params / 1e6:.0f}M) | {subnet}")

# Inference only: disable autograd so no computation graph is built.
with torch.no_grad():
    # extract the representation of the last layer
    rep = model.extract_features(wav_input_16khz)[0]

    # extract the representations of every layer; hs[-1] is expected to match rep
    hs = model.extract_features(wav_input_16khz, ret_hs=True)[0]

print(f"Representation at bottom hidden states: {torch.allclose(rep, hs[-1])}")

2022-03-29 20:53:20 | INFO | lighthubert.lighthubert | LightHuBERT Config: <lighthubert.lighthubert.LightHuBERTConfig object at 0x7f097862fe80>
2022-03-29 20:53:23 | INFO | lighthubert.lighthubert | predicting heads: [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1]
2022-03-29 20:53:23 | INFO | lighthubert.lighthubert | search space (6,530,347,008 subnets): {'atten_dim': [512, 640, 768], 'embed_dim': [512, 640, 768], 'ffn_ratio': [3.5, 4.0], 'heads_num': [8, 10, 12], 'layer_num': [12]}
2022-03-29 20:53:23 | INFO | lighthubert.lighthubert | min subnet (41 Params): {'atten_dim': [512, 512, 512, 512, 512, 512, 512, 512, 512, 512, 512, 512], 'embed_dim': 512, 'ffn_embed': [1792, 1792, 1792, 1792, 1792, 1792, 1792, 1792, 1792, 1792, 1792, 1792], 'heads_num': [8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8], 'layer_num': 12, 'slide_wsz': ['global', 'global', 'global', 'global', 'global', 'global', 'global', 'global', 'global', 'global', 'global', 'global']}
2022-03-29 20:53:23 | INFO | lighthubert.lighthubert | max 

_IncompatibleKeys(missing_keys=[], unexpected_keys=['label_embs_concat', 'final_proj.weight', 'final_proj.bias'])
subnet (Params 94M) | {'atten_dim': [768, 768, 768, 768, 768, 768, 768, 768, 768, 768, 768, 768], 'embed_dim': 768, 'ffn_embed': [3072, 3072, 3072, 3072, 3072, 3072, 3072, 3072, 3072, 3072, 3072, 3072], 'heads_num': [12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12], 'layer_num': 12, 'slide_wsz': ['global', 'global', 'global', 'global', 'global', 'global', 'global', 'global', 'global', 'global', 'global', 'global']}
Representation at bottom hidden states: True
