In [3]:
import os
import torch
import numpy as np
import matplotlib.pyplot as plt
# from Utils.errors import *
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils

In [4]:
labelmap = {'real': 0, 'fake': 1}

In [5]:
spec_path = '/home/itdfh/data/dfdc-subset/train_spectrograms_part-5'
xcep_path = '/home/itdfh/data/dfdc-subset/train_xception_part-5'    

In [8]:
def tensor_file_lists(spec_path, xcep_path, max_files=None, perc=.9):
    spec_files_train, xcep_files_train = [], []
    spec_files_val, xcep_files_val = [], []
    
    for label in ['real', 'fake']:
        train_files = []
        val_files = []
        
        all_files = os.listdir(os.path.join(spec_path, label))
        
        for i, p in enumerate(all_files):
            base_dir = os.path.join(label, p)
            full_base_dir = os.path.join(spec_path, base_dir)
            if i < len(all_files) * .9:
                train_files.extend([os.path.join(base_dir, p) for p in os.listdir(full_base_dir)])
            else:
                val_files.extend([os.path.join(base_dir, p) for p in os.listdir(full_base_dir)])
        
        spec_files_train.extend([(os.path.join(spec_path, p), labelmap[label]) for p in train_files if p[-5:] == '24.pt'])
        xcep_files_train.extend([(os.path.join(xcep_path, p), labelmap[label]) for p in train_files if p[-5:] == '24.pt'])
        
        spec_files_val.extend([(os.path.join(spec_path, p), labelmap[label]) for p in val_files if p[-5:] == '24.pt'])
        xcep_files_val.extend([(os.path.join(xcep_path, p), labelmap[label]) for p in val_files if p[-5:] == '24.pt'])
    
    return spec_files_train, xcep_files_train, spec_files_val, xcep_files_val

In [9]:
spec_files_train, xcep_files_train, spec_files_val, xcep_files_val = tensor_file_lists(spec_path, xcep_path)

In [11]:
class FrimagenetDataset(Dataset):
    '''
    FrimageNet data set for concatenating XceptionNet Features and Spectrogram features
    '''
    def __init__(self, spec_files, xcep_files, seq_size=24, max_spec_size=700):
        """
        Args:
            spectrogram_folder (string): Path to the csv file with annotations.
            xception_features_folder (string): Directory with all the images.
        """
        self.max_spec_size = max_spec_size
        self.seq_size = seq_size
        
        self.spec_files, self.xcep_files = spec_files, xcep_files
        

    def __len__(self):
        return len(self.spec_files)

    def __getitem__(self, idx):
        sf, label = self.spec_files[idx]
        xf, label = self.xcep_files[idx]
        
        # loading spec_feats with 0 padding
        spec_feats = torch.zeros((self.seq_size, self.max_spec_size))
        specs = torch.load(sf, map_location=torch.device('cpu'))[:,:self.max_spec_size]
        spec_feats[:,:specs.shape[-1]] = specs
        
        xcep_feats = torch.load(xf, map_location=torch.device('cpu'))
        x = torch.cat((xcep_feats, spec_feats), dim=-1)
        label = torch.tensor(label).long()
        return x, label


In [12]:
trainset = FrimagenetDataset(spec_files_train, xcep_files_train)
valset = FrimagenetDataset(spec_files_val, xcep_files_val)

In [17]:
trainloader = DataLoader(trainset, batch_size=50, shuffle=True, num_workers=0, drop_last=True)
valloader = DataLoader(valset, batch_size=50, shuffle=True, num_workers=0, drop_last=True)

In [18]:
for xs, labels in valloader:
    print(xs.shape, labels.shape)

torch.Size([50, 24, 2748]) torch.Size([50])
torch.Size([50, 24, 2748]) torch.Size([50])
torch.Size([50, 24, 2748]) torch.Size([50])
torch.Size([50, 24, 2748]) torch.Size([50])
torch.Size([50, 24, 2748]) torch.Size([50])
torch.Size([50, 24, 2748]) torch.Size([50])
torch.Size([50, 24, 2748]) torch.Size([50])


KeyboardInterrupt: 

In [58]:
torch.tensor(0).long()

tensor(0)

In [None]:
data = []
for label in ['real', 'fake']:
    for sf, xf, in zip(spec_files[label], xcep_files[label]):
        spec_feats = torch.load(sf, map_location=torch.device('cpu'))
        xcep_feats = torch.load(xf, map_location=torch.device('cpu'))
        all_feats = torch.cat((xcep_feats, spec_feats), dim=-1)
        lab = 0 if label == 'real' else 1
        label = torch.tensor(lab).long()
        data.append((all_feats, label))

In [None]:
data

In [31]:
# checking files
counter = 0
for sf, xf, in zip(spec_files, xcep_files):
    if not os.path.exists(xf):
        counter += 1
        print(counter)

In [None]:
data = FrimagenetDataset(spectrogram, xception)

mpbddoyjda-000-24.pt != mpbcfvbvax-000-24.pt 
mpbddoyjda-001-24.pt != mpbcfvbvax-001-24.pt 
mpbddoyjda-002-24.pt != mpbcfvbvax-002-24.pt 
mpbddoyjda-003-24.pt != mpbcfvbvax-003-24.pt 
mpbddoyjda-004-24.pt != mpbcfvbvax-004-24.pt 
mpbddoyjda-005-24.pt != mpbcfvbvax-005-24.pt 
mpbddoyjda-006-24.pt != mpbcfvbvax-006-24.pt 
mpbddoyjda-007-24.pt != mpbcfvbvax-007-24.pt 
mpbddoyjda-008-24.pt != mpbcfvbvax-008-24.pt 
mpbddoyjda-009-24.pt != mpbcfvbvax-009-24.pt 


In [None]:
count_real = 0
for i, sample in enumerate(data):
    if sample[1].tolist() == 1:
        count_real += 1
    print(sample)
print('Real', count_real)
print('Total', len(data))