# TAG-XAI
## Text-to-Audio eXplainable AI

In [None]:
import sys
import os

# Placeholder: set to the GPU id(s) that CUDA should be allowed to see.
os.environ["CUDA_VISIBLE_DEVICES"]= '' # your GPU id

# Placeholder: set to the project root; it is appended to the module search
# path so that imports can be resolved from it.
project_path = '' # your project path
sys.path.append(project_path)

In [None]:
from audiocraft.models import AudioGen
from audiocraft.models import Explainer
from audiocraft.models import MaskGenerator

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchaudio

# Prefer the GPU when PyTorch reports one as available, else use the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

In [None]:
# Load the pretrained AudioGen checkpoint and wrap it, together with its
# language model (audiogen.lm), in the project's Explainer.
audiogen = AudioGen.get_pretrained('facebook/audiogen-medium')
explainer = Explainer(audiogen, audiogen.lm)

In [None]:
# Generation settings used by the cells below.
explainer.duration = 5  # presumably seconds of audio -- TODO confirm
explainer.generation_params['use_sampling'] = True

### Generate audio

In [None]:
# Text prompt: used for the generation below and as the text being explained.
description = 'A person is walking on a gravel road.'

In [None]:
# Generate once without tracking gradients. `sequences` and the first element
# of `outs` (stored as `cond`) are reused by the mask training loop as the
# fixed reference, so both are detached from any autograd graph.
with torch.no_grad():
    sequences, _, outs = explainer.generate_with_mask([description])
    sequences = sequences.detach()
    cond = outs[0].detach()

In [None]:
from audiocraft.utils.notebook import display_audio
# Convert the generated token sequences into an audio tensor for playback.
audio = explainer.token_to_audio(sequences)

# NOTE(review): 16000 is hard-coded; presumably the model's sample rate -- confirm.
display_audio(audio, sample_rate=16000)

### Explain Text-Audio Pair

In [None]:
# Embedding of the prompt with size-1 dims removed and detached from autograd.
# It is indexed as emb.shape[0] / emb.shape[1] when the mask generators are
# built (presumably (tokens, embedding_dim) -- TODO confirm).
emb = explainer.get_token_emb(description).squeeze().detach()

In [None]:
# One MaskGenerator per entry along dim 0 of `cond`, each sized from the first
# two dimensions of `emb` and moved to the explainer's device.
emb_rows, emb_cols = emb.shape[0], emb.shape[1]
maskGenerators = nn.ModuleList(
    [
        MaskGenerator(emb_rows, emb_cols).to(explainer.device)
        for _ in range(cond.shape[0])
    ]
)

In [None]:
# Turn every generator's `hard` flag off before training.
# NOTE(review): presumably selects soft (non-binarised) masks -- confirm in MaskGenerator.
for mask_gen in maskGenerators:
    mask_gen.hard = False

In [None]:
# Training hyperparameters for the mask generators.
epochs = 50  # number of iterations of the optimisation loop
lr = 1E-3  # learning rate handed to the Adam optimizer

In [None]:
# A single Adam optimizer over the parameters of all mask generators.
optimizer = optim.Adam(maskGenerators.parameters(), lr=lr, weight_decay=1e-3)

In [None]:
# Put all mask generators into training mode.
maskGenerators.train()
# NOTE(review): the bare print() presumably keeps the notebook from echoing the
# module repr returned by .train() as the cell output -- confirm.
print()

In [None]:
def gen_mask(maskGenerators, emb):
    """Run every mask generator on ``emb`` and batch the results.

    Args:
        maskGenerators: Iterable of callables; each one is invoked as
            ``generator(emb)`` and must return a pair of tensors.
        emb: Input handed unchanged to every generator.

    Returns:
        A ``(params, reparams)`` tuple. ``params`` stacks the squeezed first
        output of each generator along a new leading dimension, and
        ``reparams`` does the same for the squeezed second output, so index
        ``i`` along dim 0 belongs to the ``i``-th generator.

    Note:
        ``squeeze()`` is called without a ``dim`` and therefore removes every
        size-1 dimension of each output.
    """
    outputs = [generator(emb) for generator in maskGenerators]
    params = torch.stack([first.squeeze() for first, _ in outputs], dim=0)
    reparams = torch.stack([second.squeeze() for _, second in outputs], dim=0)
    return params, reparams

In [None]:
# Cosine similarity computed along the last dimension; used by the training
# loop to compare the reference `cond` with the masked outputs.
cos = nn.CosineSimilarity(dim=-1, eps=1e-6)

In [None]:
# Constants for the training objective.
EPS = 1E-6  # NOTE(review): not referenced by any visible cell
beta = 1E-3  # weight of the L1 term (sum of |params|) in the loss
gamma = 1E-1  # weight of the L2 term (sqrt of the sum of params**2) in the loss

In [None]:
# The token tensor fed to the explainer does not change between epochs, so the
# sliced and permuted view of `sequences` is built once up front.
explainer_input = sequences[:, :, :-1].permute(2, 1, 0)

for epoch in range(epochs):
    # Fresh mask (and its raw parameters) from every generator.
    params, reparams = gen_mask(maskGenerators, emb)

    # Factual pass: the mask is applied as-is. Only the first half of the
    # output along dim 0 is kept (presumably the conditional branch -- TODO
    # confirm against Explainer).
    _, outs_F = explainer(explainer_input, description, reparams)
    cond_F, _ = outs_F.split(outs_F.shape[0] // 2, dim=0)

    # Counterfactual pass: same input, complementary mask.
    _, outs_CF = explainer(explainer_input, description, 1 - reparams)
    cond_CF, _ = outs_CF.split(outs_CF.shape[0] // 2, dim=0)

    # Minimising loss_F pushes the factual output towards the reference `cond`;
    # minimising loss_CF pushes the counterfactual output away from it.
    loss_F = -cos(cond, cond_F.squeeze()).sum()
    loss_CF = cos(cond, cond_CF.squeeze()).sum()

    # Penalties on the raw mask parameters.
    # NOTE(review): sqrt has no finite gradient at exactly 0 and EPS is defined
    # but unused -- confirm whether it was meant to guard this term.
    l1 = params.abs().sum()
    l2 = params.pow(2).sum().sqrt()

    loss = loss_F + loss_CF + l1 * beta + l2 * gamma

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    print(f'epoch({epoch}) :{loss.item():.2f} {loss_F.item():.2f} {loss_CF.item():.2f} {l1.item():.2f} {l2.item():.2f}')

In [None]:
# Pair each token of the prompt with its mask value averaged over dim 0 of
# `reparams` (one row per mask generator) and print them side by side.
tokenizer = explainer.lm.condition_provider.conditioners.description.t5_tokenizer
token_scores = reparams.mean(dim=0).tolist()
for tkn, m in zip(tokenizer.tokenize(description), token_scores):
    print(f'{tkn} : {m:.2f}')

### Generate audio with the factual mask

In [None]:
# Sampling settings for this generation run.
explainer.generation_params['use_sampling'] = True
explainer.generation_params['top_k'] = 250
# Switch the mask generators to evaluation mode before producing the mask.
maskGenerators.eval()

with torch.no_grad():
    # `params` is not used in this cell; only the `reparams` mask is passed on.
    params, reparams = gen_mask(maskGenerators, emb)
    # Factual run: generate with the learned mask applied as-is.
    sequences_F, _, _ = explainer.generate_with_mask([description], mask=reparams)
    audio_F = explainer.token_to_audio(sequences_F)

In [None]:
# Play back the audio generated with the factual mask.
display_audio(audio_F, sample_rate=16000)

In [None]:
# Write the factual audio to 'audio_F.wav' at 16000 Hz; squeeze(0) drops the
# leading dimension when it has size 1, and the tensor is moved to the CPU.
torchaudio.save('audio_F.wav', audio_F.squeeze(0).detach().cpu(), 16000)

### Generate audio with the counterfactual mask

In [None]:
# Sampling settings for this generation run.
explainer.generation_params['use_sampling'] = True
explainer.generation_params['top_k'] = 250
# Switch the mask generators to evaluation mode before producing the mask.
maskGenerators.eval()

with torch.no_grad():
    # `params` is not used in this cell; only the `reparams` mask is passed on.
    params, reparams = gen_mask(maskGenerators, emb)
    # Counterfactual run: generate with the complement of the learned mask.
    sequences_CF, _, _ = explainer.generate_with_mask([description], mask=1-reparams)
    audio_CF = explainer.token_to_audio(sequences_CF)

In [None]:
# Play back the audio generated with the counterfactual mask.
display_audio(audio_CF, sample_rate=16000)

In [None]:
# Write the counterfactual audio to 'audio_CF.wav' at 16000 Hz; squeeze(0)
# drops the leading dimension when it has size 1, and the tensor is moved to the CPU.
torchaudio.save('audio_CF.wav', audio_CF.squeeze(0).detach().cpu(), 16000)