# Load from DN3 (original BENDR)

In [1]:
import dn3_ext
import torch 

In [2]:
# Instantiate the convolutional encoder, report how many of its parameters are
# trainable, then push a random batch through it to inspect the output shape.
conv_encoder = dn3_ext.ConvEncoderBENDR(encoder_h=512, in_features=128)
encoder_param_count = sum(
    weight.numel() for weight in conv_encoder.parameters() if weight.requires_grad
)
print(encoder_param_count)
t1 = torch.rand((16, 128, 1024))
conv_encoder(t1).shape

4137984




torch.Size([16, 512, 11])

In [5]:
# Instantiate the contextualizer, report how many of its parameters are
# trainable, then feed it a random batch to inspect the output shape.
contextualizer = dn3_ext.BENDRContextualizer(in_features=128)
contextualizer_param_count = sum(
    weight.numel() for weight in contextualizer.parameters() if weight.requires_grad
)
print(contextualizer_param_count)
t2 = torch.rand((16, 128, 11))
contextualizer(t2).shape

23783225


torch.Size([16, 128, 12])

In [19]:
# Wrap the encoder and the contextualizer in dn3_ext.BendingCollegeWav2Vec.
# NOTE(review): the recorded encoder output has 512 features, while the
# contextualizer cell in this notebook uses in_features=128. The execution
# counts are also out of order, so `contextualizer` may have been redefined in
# a cell that no longer exists -- confirm this runs on a fresh kernel.
model = dn3_ext.BendingCollegeWav2Vec(conv_encoder, contextualizer)

In [22]:
# Run the combined model on the random batch `t1` and display the shape of
# each of the three tensors it returns.
logits, z, mask = model.forward(t1)
tuple(tensor.shape for tensor in (logits, z, mask))

(torch.Size([176, 101]), torch.Size([16, 512, 11]), torch.Size([16, 11]))

In [8]:
# Sanity check of a stock PyTorch transformer encoder layer: with
# batch_first=True the input is (batch, sequence, d_model) and the output
# keeps that shape.
encoder_layer = torch.nn.TransformerEncoderLayer(
    d_model=512,
    nhead=8,
    batch_first=True,
)
src = torch.rand((32, 10, 512))
out = encoder_layer(src)
out.shape

torch.Size([32, 10, 512])

In [12]:
def _generate_negatives(self, z, num_negatives=20):
        """Generate negative samples to compare each sequence location against"""
        batch_size, feat, full_len = z.shape
        z_k = z.permute([0, 2, 1]).reshape(-1, feat)
        with torch.no_grad():
            # candidates = torch.arange(full_len).unsqueeze(-1).expand(-1, self.num_negatives).flatten()
            negative_inds = torch.randint(0, full_len-1, size=(batch_size, full_len * num_negatives))
            # From wav2vec 2.0 implementation, I don't understand
            # negative_inds[negative_inds >= candidates] += 1

            for i in range(1, batch_size):
                negative_inds[i] += i * full_len

        z_k = z_k[negative_inds.view(-1)].view(batch_size, full_len, num_negatives, feat)
        return z_k, negative_inds

In [13]:
z = torch.rand(16, 512, 10)
z_k, negative_inds = _generate_negatives(None, z)
z_k.shape, negative_inds.shape

(torch.Size([16, 10, 20, 512]), torch.Size([16, 200]))

# Dataloader test

In [1]:
import datasets
import torch 

In [2]:
# Build the multi-participant dataset from the stream directory on the
# network scratch drive.
stream_root = "/itet-stor/wolflu/net_scratch/projects/EEGEyeNet_experimental/data/stream"
dataset = datasets.MultiParticipantDataset(root_dir=stream_root)

In [10]:
# Batch the dataset in shuffled groups of 16, loading with 4 worker processes.
dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=16,
    shuffle=True,
    num_workers=4,
)

In [4]:
import pandas as pd 

# Peek at one raw substream file. The preview below shows two leftover index
# columns ("Unnamed: 0.1", "Unnamed: 0") ahead of the channel columns.
substream_csv = "/itet-stor/wolflu/net_scratch/projects/EEGEyeNet_experimental/data/stream/zuco/synch_min/YDG/YDG_NR1_EEG/substream_1.csv"
df = pd.read_csv(substream_csv)
df.head()

Unnamed: 0.1,Unnamed: 0,channel_0,channel_1,channel_2,channel_3,channel_4,channel_5,channel_6,channel_7,channel_8,...,channel_122,channel_123,channel_124,channel_125,channel_126,channel_127,x,y,latency,event
0,0,56.257469,18.051191,26.662785,29.870787,22.915567,20.315931,9.134151,92.910492,56.556118,...,33.492527,24.869997,2.918007,35.776516,9.448587,-57.668213,203.838272,587.647766,1022,L_sa
1,1,57.057518,20.887222,22.547913,39.934853,19.913111,18.151897,7.447501,90.984131,55.496334,...,37.711662,26.393988,9.511328,38.023495,8.503574,-58.431644,240.135223,573.144897,1023,L_sa
2,2,59.40366,21.838907,28.809597,39.751484,18.69854,18.242762,7.906695,92.435623,57.216488,...,41.699432,22.582708,15.671942,40.256069,6.635967,-58.401382,335.638641,561.219055,1024,L_sa
3,3,58.965668,17.897776,30.020729,35.211407,20.898821,19.694096,8.783145,88.797241,56.081818,...,39.327866,23.335234,13.596437,38.504345,6.441703,-57.743423,429.658691,547.339172,1025,L_sa
4,4,58.359009,17.951101,31.22024,36.926132,21.192352,19.414371,8.368823,87.650238,55.861965,...,36.803974,27.312857,12.131395,37.479729,7.741166,-54.807995,517.024902,526.905518,1026,L_sa
