In [1]:
from multitask_missing_q10 import get_loader, MySet, binary_cross_entropy_with_logits, TemporalDecay, FeatureRegression, to_var
import torch
from torch import nn
from torch.nn import functional as F   
from torch.autograd import Variable
from lightning.pytorch import seed_everything
seed_everything(42)

# ---- Experiment configuration -------------------------------------------
SEQ_LEN = 40                  # number of time steps the models iterate over per sequence
RNN_HID_SIZE = 32             # hidden size of the recurrent cell
batch_size = 32
model_name = 'BRITS_ATT'      # alternative noted by the author: 'RITS'
question = 'feeling_lately'
open_face = 'eye_gaze'        # feature group; must be a key of OPEN_FACE_N_SERIES below
epochs = 100
lr = 1e-3
repetitions = 1
ratio_missing = 0.05
type_missing = 'Random'       # alternative noted by the author: 'CMV'
rnn_name = 'LSTM'             # alternative: 'GRU'
experiment_name = 'exp01'

# Number of series (features per time step) for each supported `open_face` value.
OPEN_FACE_N_SERIES = {
    'action_unit': 14,
    'eye_gaze': 12,
    'landmark': 136,
    'all': 14 + 12 + 136,
}

# Fail here, with a clear message, rather than with a NameError on N_SERIES
# when a model is built later in the notebook.
if open_face not in OPEN_FACE_N_SERIES:
    raise ValueError(
        f"Unknown open_face {open_face!r}; expected one of {sorted(OPEN_FACE_N_SERIES)}"
    )
N_SERIES = OPEN_FACE_N_SERIES[open_face]


[rank: 0] Seed set to 42


In [2]:
class Model_rits_att(nn.Module):
    """Single-direction recurrent imputation model with an attention-based classifier head.

    At every time step the model estimates the current observation from the
    decayed hidden state, blends that estimate with a feature-regression
    estimate, fills the unobserved entries and feeds the completed vector
    (concatenated with the mask) to an LSTM/GRU cell.  The hidden states of all
    steps are then pooled by a multi-row attention and mapped to one logit.

    The layer sizes come from the module-level constants ``N_SERIES``,
    ``RNN_HID_SIZE`` and ``SEQ_LEN``, which must be defined before instantiation.
    """

    # Attention head sizes; `out` in build() depends on ATT_N_ROWS as well.
    ATT_HIDDEN_SIZE = 350
    ATT_N_ROWS = 30
    SUPPORTED_RNN_NAMES = ('LSTM', 'GRU')

    def __init__(self, rnn_name='LSTM'):
        """Create the model.

        Args:
            rnn_name: ``'LSTM'`` or ``'GRU'``; selects the recurrent cell type.

        Raises:
            ValueError: if ``rnn_name`` is not supported (raised by ``build``).
        """
        super(Model_rits_att, self).__init__()
        self.rnn_name = rnn_name
        self.build()

        # Attention following AudiBert
        self.W_s1 = nn.Linear(RNN_HID_SIZE, self.ATT_HIDDEN_SIZE)
        self.W_s2 = nn.Linear(self.ATT_HIDDEN_SIZE, self.ATT_N_ROWS)

    def build(self):
        """Instantiate the recurrent cell, decay, regression and output layers.

        Raises:
            ValueError: if ``self.rnn_name`` is neither ``'LSTM'`` nor ``'GRU'``.
        """
        # Validate first: previously an unknown name silently created no cell
        # and forward() then never updated the hidden state.
        if self.rnn_name not in self.SUPPORTED_RNN_NAMES:
            raise ValueError(
                f"Unsupported rnn_name {self.rnn_name!r}; expected one of {self.SUPPORTED_RNN_NAMES}"
            )

        # The cell input is the completed observation concatenated with its mask.
        if self.rnn_name == 'LSTM':
            self.rnn_cell = nn.LSTMCell(N_SERIES * 2, RNN_HID_SIZE)
        else:
            self.rnn_cell = nn.GRUCell(N_SERIES * 2, RNN_HID_SIZE)

        self.temp_decay_h = TemporalDecay(input_size = N_SERIES, output_size = RNN_HID_SIZE, diag = False)
        self.temp_decay_x = TemporalDecay(input_size = N_SERIES, output_size = N_SERIES, diag = True)

        self.hist_reg = nn.Linear(RNN_HID_SIZE, N_SERIES)
        self.feat_reg = FeatureRegression(N_SERIES)

        self.weight_combine = nn.Linear(N_SERIES * 2, N_SERIES)

        # NOTE(review): this dropout layer is created but forward() never applies it.
        self.dropout = nn.Dropout(p = 0.25)
        # The classifier consumes the flattened (ATT_N_ROWS x RNN_HID_SIZE) attention output.
        self.out = nn.Linear(RNN_HID_SIZE * self.ATT_N_ROWS, 1)

    def attention_rnn(self, rnn_output):
        """Compute attention weights over the time axis.

        Args:
            rnn_output: tensor of shape (batch, seq_len, hidden) whose last
                dimension matches the input size of ``W_s1``.

        Returns:
            Tensor of shape (batch, n_rows, seq_len), where n_rows is the output
            size of ``W_s2``; each row is a softmax distribution over time steps.
        """
        scores = self.W_s2(torch.tanh(self.W_s1(rnn_output)))   # (batch, seq_len, n_rows)
        scores = scores.permute(0, 2, 1)                         # (batch, n_rows, seq_len)
        return F.softmax(scores, dim=2)

    def _assemble_input_for_training(self, data):
        """Cast a batch dict and regroup it into forward/backward/label entries.

        Every value of ``data`` is converted **in place**: ``'label'`` to int64,
        everything else to the dtype/device of the ``W_s1`` parameters.

        Args:
            data (dict): tensors keyed by 'X', 'missing_mask', 'deltas',
                'back_X', 'back_missing_mask', 'back_deltas' and 'label'.

        Returns:
            dict: ``{'forward': {...}, 'backward': {...}, 'label': ...}`` where each
            direction holds 'X', 'missing_mask' and 'deltas'.

        Raises:
            KeyError: if one of the keys listed above is missing.

        NOTE(review): ``forward`` in this class reads 'values'/'masks' rather than
        'X'/'missing_mask', so this output is not directly consumable by it.
        """
        reference_param = next(iter(self.W_s1.parameters()))
        for k, v in data.items():
            if k == 'label':
                data[k] = v.long()
            else:
                data[k] = v.type_as(reference_param)
        final_dict = {
            'forward': {"X": data['X'], "missing_mask": data['missing_mask'], "deltas": data['deltas']}, #TODO: check if this is correct
            'backward': {"X": data['back_X'], "missing_mask": data['back_missing_mask'], "deltas": data['back_deltas']},
            'label': data['label']
        }
        return final_dict

    def forward(self, data, direct, filter_train=False):
        """Impute one direction of the batch and classify each sequence.

        Args:
            data (dict): must contain 'is_train', 'label' and ``data[direct]`` with
                'values', 'masks', 'deltas' (each indexed as (batch, time, feature))
                plus 'evals' and 'eval_masks'.
            direct (str): key of the direction to process ('forward' or 'backward'
                in this notebook).
            filter_train (bool): if True, only rows with ``is_train == 1`` are
                processed; otherwise the whole batch is used.

        Returns:
            dict with keys 'loss_imp', 'loss_clf', 'loss', 'predictions',
            'imputations', 'label', 'is_train', 'evals', 'eval_masks'.

        NOTE(review): with ``filter_train=True`` the returned 'is_train', 'evals'
        and 'eval_masks' still cover the full batch while 'predictions',
        'imputations' and 'label' only cover the selected rows.
        """
        is_train = data['is_train'].view(-1, 1)
        non_zero = torch.where(is_train == 1)[0] if filter_train else slice(None)

        # (rows, time, features) slices of the selected direction
        values = data[direct]['values'][non_zero]
        masks = data[direct]['masks'][non_zero]
        deltas = data[direct]['deltas'][non_zero]

        evals = data[direct]['evals']
        eval_masks = data[direct]['eval_masks']

        labels = data['label'].view(-1, 1)[non_zero]

        # Allocate the recurrent state next to the inputs (same device and dtype)
        # instead of on the default CPU/float32.
        n_rows = values.size(0)
        h = values.new_zeros((n_rows, RNN_HID_SIZE))
        c = values.new_zeros((n_rows, RNN_HID_SIZE))

        x_loss = 0.0
        hidden_states = []   # one (rows, RNN_HID_SIZE) tensor per time step
        imputations = []

        for t in range(SEQ_LEN):
            x = values[:, t, :]   # observations at step t
            m = masks[:, t, :]    # mask at step t; multiplies the error terms and selects x in the blends below
            d = deltas[:, t, :]   # deltas at step t, fed to both decay modules

            gamma_h = self.temp_decay_h(d)   # factor applied to the hidden state
            gamma_x = self.temp_decay_x(d)   # per-feature factor used for the blend weights

            h = h * gamma_h

            # Estimate of the current features from the hidden state.
            x_h = self.hist_reg(h)
            # Mean absolute error over entries where m is non-zero.
            x_loss += torch.sum(torch.abs(x - x_h) * m) / (torch.sum(m) + 1e-5)

            # Observed values where m == 1, hidden-state estimate elsewhere.
            x_c = m * x + (1 - m) * x_h

            # Second estimate produced by the feature-regression module.
            z_h = self.feat_reg(x_c)
            x_loss += torch.sum(torch.abs(x - z_h) * m) / (torch.sum(m) + 1e-5)

            # Blend weights computed from the decay factors and the mask.
            alpha = self.weight_combine(torch.cat([gamma_x, m], dim = 1))

            c_h = alpha * z_h + (1 - alpha) * x_h
            x_loss += torch.sum(torch.abs(x - c_h) * m) / (torch.sum(m) + 1e-5)

            # Final completed vector for this step.
            c_c = m * x + (1 - m) * c_h

            inputs = torch.cat([c_c, m], dim = 1)
            if self.rnn_name == 'LSTM':
                h, c = self.rnn_cell(inputs, (h, c))
            elif self.rnn_name == 'GRU':
                h = self.rnn_cell(inputs, h)
            hidden_states.append(h)
            imputations.append(c_c.unsqueeze(dim = 1))

        imputations = torch.cat(imputations, dim = 1)
        # (rows, SEQ_LEN, RNN_HID_SIZE); stacking avoids in-place writes into a
        # pre-allocated buffer that takes part in autograd.
        H_rnn = torch.stack(hidden_states, dim = 1)

        # Attention pooling over time, then flatten to (rows, n_att_rows * hidden).
        attn_weight_matrix = self.attention_rnn(H_rnn)
        hidden_matrix = torch.bmm(attn_weight_matrix, H_rnn)
        attention_output = hidden_matrix.view(-1, hidden_matrix.size()[1]*hidden_matrix.size()[2])

        y_h = self.out(attention_output)
        y_loss = binary_cross_entropy_with_logits(y_h, labels, reduce = False)
        # Average the classification loss over rows flagged as training rows.
        y_loss = torch.sum(y_loss * is_train[non_zero]) / (torch.sum(is_train[non_zero]) + 1e-5)

        y_h = torch.sigmoid(y_h)

        return {'loss_imp': x_loss / SEQ_LEN, 'loss_clf': y_loss * 0.1, "loss": x_loss/SEQ_LEN + y_loss*0.1,
                'predictions': y_h,
                'imputations': imputations,
                'label': labels, 'is_train': is_train,
                'evals': evals, 'eval_masks': eval_masks}

    def run_on_batch(self, data, optimizer):
        """Run the forward direction on ``data`` and optionally take one optimizer step.

        Args:
            data (dict): batch in the format expected by ``forward``.
            optimizer: optimizer over this model's parameters, or ``None`` to
                evaluate without updating.

        Returns:
            dict: the output of ``forward`` for ``direct='forward'``.
        """
        ret = self(data, direct = 'forward')

        if optimizer is not None:
            optimizer.zero_grad()
            ret['loss'].backward()
            optimizer.step()

        return ret

In [4]:
class Model_brits_att(nn.Module):
    """Bidirectional wrapper around two ``Model_rits_att`` instances.

    One sub-model processes ``data['forward']`` and the other ``data['backward']``;
    the backward outputs are flipped back along the time axis and merged with the
    forward outputs.
    """

    def __init__(self, rnn_name='LSTM'):
        """Create both directional sub-models.

        Args:
            rnn_name: recurrent cell type passed on to each ``Model_rits_att``.
        """
        super(Model_brits_att, self).__init__()
        self.rnn_name = rnn_name
        self.build()

    def build(self):
        """Instantiate the forward (``rits_f``) and backward (``rits_b``) sub-models."""
        self.rits_f = Model_rits_att(self.rnn_name)
        self.rits_b = Model_rits_att(self.rnn_name)

    def forward(self, data):
        """Run both directions on the whole batch (``filter_train=False``) and merge them."""
        ret_f = self.rits_f(data, 'forward', filter_train=False)
        ret_b = self.reverse(self.rits_b(data, 'backward', filter_train=False))
        # TODO (from the author): finish the "2" variant to check whether y_loss / x_loss
        # and the resulting gradient updates match between full and filtered batches.
        return self.merge_ret(ret_f, ret_b)

    def forward2(self, data):
        """Same as ``forward`` but with ``filter_train=True`` in both sub-models."""
        ret_f = self.rits_f(data, 'forward', filter_train=True)
        ret_b = self.reverse(self.rits_b(data, 'backward', filter_train=True))
        return self.merge_ret(ret_f, ret_b)

    def merge_ret(self, ret_f, ret_b):
        """Combine the two directional outputs.

        ``ret_f`` is modified in place and returned: 'loss' becomes forward loss +
        backward loss + consistency loss, and 'predictions' / 'imputations' become
        the element-wise mean of both directions.  All other entries (including
        'loss_imp' and 'loss_clf') keep the forward-direction values.
        """
        loss_c = self.get_consistency_loss(ret_f['imputations'], ret_b['imputations'])
        loss = ret_f['loss'] + ret_b['loss'] + loss_c

        predictions = (ret_f['predictions'] + ret_b['predictions']) / 2
        imputations = (ret_f['imputations'] + ret_b['imputations']) / 2

        ret_f['loss'] = loss
        ret_f['predictions'] = predictions
        ret_f['imputations'] = imputations

        return ret_f

    def get_consistency_loss(self, pred_f, pred_b):
        """Return the mean squared difference between the two tensors."""
        return torch.pow(pred_f - pred_b, 2.0).mean()

    def reverse(self, ret):
        """Flip every tensor in ``ret`` with more than one dimension along dim 1.

        ``ret`` is updated in place and returned.  ``Tensor.flip`` is used so that
        no index tensor has to be created on (and moved to) a particular device.
        """
        for key in ret:
            tensor_ = ret[key]
            if tensor_.dim() > 1:
                ret[key] = tensor_.flip(1)
        return ret

    def _backward_and_collect(self, loss, optimizer):
        """Reset gradients, backpropagate ``loss`` and return ``{param_name: grad copy}``."""
        optimizer.zero_grad()
        # The same graph is backpropagated through several times by the caller.
        loss.backward(retain_graph=True)
        return {
            name: param.grad.clone()
            for name, param in self.named_parameters()
            if param.grad is not None
        }

    @staticmethod
    def _grads_match(grad_a, grad_b):
        """Return True if both gradients are absent, or both present and close (rtol=1e-2)."""
        if grad_a is None or grad_b is None:
            return grad_a is None and grad_b is None
        return torch.allclose(grad_a, grad_b, rtol=1e-2)

    def test_run_on_batch(self, data, optimizer):
        """Diagnostic: compare gradients of ``forward`` vs ``forward2`` on one batch.

        For the 'loss_imp', 'loss_clf' and 'loss' entries of both outputs the
        gradients are computed separately and compared parameter by parameter;
        the result is printed.  No optimizer step is taken.

        Gradients are keyed by parameter name.  Collecting them into plain lists
        (skipping parameters without a gradient) produced lists of different
        lengths for the different losses, which both mislabelled the layers and
        crashed when the lists were zipped together.

        Args:
            data (dict): batch accepted by ``forward``; ``data['is_train']`` is read.
            optimizer: optimizer over this model's parameters, or ``None`` to skip
                the gradient comparison.

        Returns:
            tuple: ``(ret1, ret2)``, the outputs of ``forward`` and ``forward2``.
        """
        ret1 = self.forward(data)
        ret2 = self.forward2(data)
        if optimizer is not None:
            grads = {}
            for loss_key in ('loss_imp', 'loss_clf', 'loss'):
                grads[loss_key] = (
                    self._backward_and_collect(ret1[loss_key], optimizer),
                    self._backward_and_collect(ret2[loss_key], optimizer),
                )

            print(f"Are is_train=0 present in batch?: {torch.any(data['is_train'] == 0)}")
            for layer, _ in self.named_parameters():
                # Skip parameters that received no gradient from any of the losses.
                if not any(layer in g1 or layer in g2 for g1, g2 in grads.values()):
                    continue
                same_imp = self._grads_match(grads['loss_imp'][0].get(layer), grads['loss_imp'][1].get(layer))
                same_clf = self._grads_match(grads['loss_clf'][0].get(layer), grads['loss_clf'][1].get(layer))
                same_full = self._grads_match(grads['loss'][0].get(layer), grads['loss'][1].get(layer))
                print(f"Are all grads equal for layer {layer} with respect to imp_loss?: {same_imp}")
                print(f"Are all grads equal for layer {layer} with respect to for clf_loss?: {same_clf}")
                print(f"Are all grads equal for layer {layer} with respect to full_loss?: {same_full}")
                print()

        return ret1, ret2

In [5]:
# Data loader built from the notebook configuration.
loader_options = {
    'question': question,
    'open_face': open_face,
    'ratio_missing': ratio_missing,
    'type_missing': type_missing,
    'shuffle': True,
    'batch_size': batch_size,
}
ricardo_loader = get_loader(**loader_options)

# Reference bidirectional model and its Adam optimizer.
model = Model_brits_att(rnn_name=rnn_name)
opt = torch.optim.Adam(params=model.parameters(), lr=lr)

In [6]:
# Take the first batch from the loader and convert it with to_var.
batch = to_var(next(iter(ricardo_loader)))
# Call the module itself rather than .forward() so registered hooks are honoured.
ret1 = model(batch)



In [7]:
ret1

{'loss_imp': tensor(1.3762, grad_fn=<DivBackward0>),
 'loss_clf': tensor(0.0705, grad_fn=<MulBackward0>),
 'loss': tensor(3.0466, grad_fn=<AddBackward0>),
 'predictions': tensor([[0.4796],
         [0.4785],
         [0.4774],
         [0.4796],
         [0.4771],
         [0.4818],
         [0.4781],
         [0.4764],
         [0.4799],
         [0.4807],
         [0.4815],
         [0.4770],
         [0.4781],
         [0.4785],
         [0.4816],
         [0.4814],
         [0.4775],
         [0.4790],
         [0.4825],
         [0.4771],
         [0.4792],
         [0.4785],
         [0.4781],
         [0.4778],
         [0.4786],
         [0.4751],
         [0.4813],
         [0.4808],
         [0.4763],
         [0.4781],
         [0.4808],
         [0.4784]], grad_fn=<DivBackward0>),
 'imputations': tensor([[[ 0.0274,  0.2098, -0.9674,  ..., -0.1025,  0.2313, -0.9459],
          [ 0.0100,  0.2078, -0.9595,  ..., -0.0484,  0.2020, -0.9595],
          [ 0.1600,  0.2573, -0.9499,

In [8]:
from src.methods.brits.lightningmodule import BRITSLightningModule


class DummyDataModule:
    """Minimal datamodule stand-in that only carries a ``data_info`` dict."""

    def __init__(self):
        self.data_info = {}


class DummyTrainer:
    """Minimal trainer stand-in that only carries a ``datamodule``."""

    def __init__(self):
        self.datamodule = DummyDataModule()


# Second implementation, configured with the same settings as the notebook config.
model2 = BRITSLightningModule(
    rnn_hidden_size=RNN_HID_SIZE,
    question=question,
    open_face=open_face,
    ratio_missing=ratio_missing,
    type_missing=type_missing,
    rnn_name=rnn_name,
    lr=lr,
    batch_size=batch_size,
    repetitions=repetitions,
    epochs=epochs,
    experiment_name=experiment_name,
    pypots=False,
)

# Attach the stand-in trainer and fill in the dataset description before setup().
# Presumably setup() reads trainer.datamodule.data_info; verify against the module.
model2.trainer = DummyTrainer()
model2.trainer.datamodule.data_info = dict(
    landmark=136,
    eye_gaze=12,
    action_unit=14,
    all=162,
    n_time_steps=SEQ_LEN,
    n_features=N_SERIES,
    n_classes=2,
)
model2.setup(stage='fit')

In [22]:
from copy import deepcopy
batch2 = deepcopy(batch)
batch2
# model2.forward(batch, training=True)

{'forward': {'values': tensor([[[ 0.0274,  0.2098, -0.9674,  ..., -0.1025,  0.2313, -0.9459],
           [ 0.0100,  0.2078, -0.9595,  ..., -0.0484,  0.2020, -0.9595],
           [ 0.1600,  0.2573, -0.9499,  ..., -0.3722,  0.4083, -0.8149],
           ...,
           [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
           [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
           [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000]],
  
          [[ 0.0230,  0.2504, -0.9678,  ..., -0.1237,  0.0000, -0.9088],
           [ 0.0424,  0.2503, -0.9662,  ..., -0.1164,  0.4028, -0.9069],
           [ 0.0594,  0.2558, -0.9629,  ..., -0.1623,  0.3587, -0.9182],
           ...,
           [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
           [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
           [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000]],
  
          [[ 0.0064,  0.2969, -0.9544,  ..., -0.2587,  0.3658, 

In [23]:
# Rename the per-direction keys to the names model2 is called with below
# ('values' -> 'X', 'masks' -> 'missing_mask').  The guard makes the cell safe
# to re-run: once a key has been renamed the pop is skipped instead of raising
# KeyError.
for direction in ('forward', 'backward'):
    for old_key, new_key in (('values', 'X'), ('masks', 'missing_mask')):
        if old_key in batch2[direction]:
            batch2[direction][new_key] = batch2[direction].pop(old_key)
batch2['label'] = batch2['label'].long()



In [24]:
batch2['forward'].keys()

dict_keys(['deltas', 'evals', 'eval_masks', 'X', 'missing_mask'])

In [29]:
# Call the module itself rather than .forward() so registered hooks are honoured.
ret2 = model2(batch2, training=False)

In [30]:
ret2.keys()

dict_keys(['imputed_data', 'classification_pred', 'consistency_loss', 'reconstruction_loss', 'loss', 'reconstruction', 'classification_loss', 'f_reconstruction', 'b_reconstruction'])

In [31]:
ret1.keys()

dict_keys(['loss_imp', 'loss_clf', 'loss', 'predictions', 'imputations', 'label', 'is_train', 'evals', 'eval_masks'])

In [35]:
ret1['imputations']

tensor([[[ 0.0274,  0.2098, -0.9674,  ..., -0.1025,  0.2313, -0.9459],
         [ 0.0100,  0.2078, -0.9595,  ..., -0.0484,  0.2020, -0.9595],
         [ 0.1600,  0.2573, -0.9499,  ..., -0.3722,  0.4083, -0.8149],
         ...,
         [ 0.0385, -0.0483, -0.0135,  ...,  0.0392,  0.0822, -0.0963],
         [ 0.0418, -0.0525, -0.0162,  ...,  0.0369,  0.0830, -0.0941],
         [ 0.0466, -0.0609, -0.0252,  ...,  0.0364,  0.0835, -0.0901]],

        [[ 0.0230,  0.2504, -0.9678,  ..., -0.1237, -0.0829, -0.9088],
         [ 0.0424,  0.2503, -0.9662,  ..., -0.1164,  0.4028, -0.9069],
         [ 0.0594,  0.2558, -0.9629,  ..., -0.1623,  0.3587, -0.9182],
         ...,
         [ 0.0364, -0.0477, -0.0175,  ...,  0.0417,  0.0839, -0.0963],
         [ 0.0401, -0.0520, -0.0206,  ...,  0.0397,  0.0842, -0.0944],
         [ 0.0466, -0.0609, -0.0252,  ...,  0.0364,  0.0835,  0.0354]],

        [[ 0.0064,  0.2969, -0.9544,  ..., -0.2587,  0.3658, -0.8937],
         [-0.0265,  0.2859, -0.9572,  ..., -0

In [36]:
ret2['imputed_data']

tensor([[[ 0.0274,  0.2098, -0.9674,  ..., -0.1025,  0.2313, -0.9459],
         [ 0.0100,  0.2078, -0.9595,  ..., -0.0484,  0.2020, -0.9595],
         [ 0.1600,  0.2573, -0.9499,  ..., -0.3722,  0.4083, -0.8149],
         ...,
         [ 0.0395,  0.0706,  0.0614,  ...,  0.0154,  0.0324,  0.1001],
         [ 0.0391,  0.0695,  0.0585,  ...,  0.0166,  0.0314,  0.1000],
         [ 0.0367,  0.0661,  0.0536,  ...,  0.0194,  0.0320,  0.1018]],

        [[ 0.0230,  0.2504, -0.9678,  ..., -0.1237, -0.0259, -0.9088],
         [ 0.0424,  0.2503, -0.9662,  ..., -0.1164,  0.4028, -0.9069],
         [ 0.0594,  0.2558, -0.9629,  ..., -0.1623,  0.3587, -0.9182],
         ...,
         [ 0.0383,  0.0702,  0.0632,  ...,  0.0156,  0.0336,  0.1007],
         [ 0.0381,  0.0695,  0.0599,  ...,  0.0165,  0.0323,  0.1003],
         [ 0.0367,  0.0661,  0.0536,  ...,  0.0194,  0.0320,  0.1624]],

        [[ 0.0064,  0.2969, -0.9544,  ..., -0.2587,  0.3658, -0.8937],
         [-0.0265,  0.2859, -0.9572,  ..., -0

In [37]:
model

Model_brits_att(
  (rits_f): Model_rits_att(
    (rnn_cell): LSTMCell(24, 32)
    (temp_decay_h): TemporalDecay()
    (temp_decay_x): TemporalDecay()
    (hist_reg): Linear(in_features=32, out_features=12, bias=True)
    (feat_reg): FeatureRegression()
    (weight_combine): Linear(in_features=24, out_features=12, bias=True)
    (dropout): Dropout(p=0.25, inplace=False)
    (out): Linear(in_features=960, out_features=1, bias=True)
    (W_s1): Linear(in_features=32, out_features=350, bias=True)
    (W_s2): Linear(in_features=350, out_features=30, bias=True)
  )
  (rits_b): Model_rits_att(
    (rnn_cell): LSTMCell(24, 32)
    (temp_decay_h): TemporalDecay()
    (temp_decay_x): TemporalDecay()
    (hist_reg): Linear(in_features=32, out_features=12, bias=True)
    (feat_reg): FeatureRegression()
    (weight_combine): Linear(in_features=24, out_features=12, bias=True)
    (dropout): Dropout(p=0.25, inplace=False)
    (out): Linear(in_features=960, out_features=1, bias=True)
    (W_s1): Linea

In [38]:
model2

BRITSLightningModule(
  (model): MultiTaskBRITS(
    (model): MyBackboneBRITS(
      (rits_f): MyBackboneRITS(
        (rnn_cell): LSTMCell(24, 32)
        (temp_decay_h): TemporalDecay()
        (temp_decay_x): TemporalDecay()
        (hist_reg): Linear(in_features=32, out_features=12, bias=True)
        (feat_reg): FeatureRegression()
        (combining_weight): Linear(in_features=24, out_features=12, bias=True)
      )
      (rits_b): MyBackboneRITS(
        (rnn_cell): LSTMCell(24, 32)
        (temp_decay_h): TemporalDecay()
        (temp_decay_x): TemporalDecay()
        (hist_reg): Linear(in_features=32, out_features=12, bias=True)
        (feat_reg): FeatureRegression()
        (combining_weight): Linear(in_features=24, out_features=12, bias=True)
      )
    )
    (f_classifier): None
    (b_classifier): None
    (out): Linear(in_features=960, out_features=2, bias=True)
    (W_s1): Linear(in_features=32, out_features=30, bias=True)
    (W_s2): Linear(in_features=30, out_feature