In [1]:
import pickle
import torch
import os
import numpy as np
import pandas as pd
from pprint import pprint
from collections import Counter
from tqdm import tqdm

Checking for equality between two tuples that include tensors cannot take advantage of the == operator. Check the following example:


In [16]:
import pickle
import torch 
import os

directo = '/homes/gws/jacopo/trelium/SeattleFluStudy/debugbatch/2'
filename = 'debug_batch_2.p'

with open(os.path.join(directo, filename), 'rb') as fp:
    data = pickle.load(fp)
    obs1 = (data['inputs_embeds'][0], data['label'][0], data['participant_id'][0], data['id'][0], data['end_date_str'][0])
    obs2 = (data['inputs_embeds'][1], data['label'][1], data['participant_id'][1], data['id'][1], data['end_date_str'][1])
    obs3 = (data['inputs_embeds'][0], data['label'][0], data['participant_id'][0], data['id'][0], data['end_date_str'][0])    
    
    print(hash(obs1) == hash(obs2))  #False 
    print(hash(obs1) == hash(obs3))  #also False! 'in' set operator cannot be used
        
    print(obs1 == obs2)
    print(obs1 == obs3)
    
    #https://stackoverflow.com/questions/8705378/pythons-in-set-operator

False
False


RuntimeError: Boolean value of Tensor with more than one value is ambiguous

In [19]:
obss = set()
obss.add(obs2)
obss.add(obs3) #don't add obs1



print(obs1 in obss)
print(hash(obs1) == hash(obs3)) 

False
False


In [None]:
"""
count = Counter()
for epoch in allresults.keys():
    count.update(allresults[epoch]['files_in_epoch'])

print(len(count))          
"""

We need to define a custom function to check for equality item-to-item in our examples.

In [3]:
def check_equality(observation, other):
    """Returns true if two given example tuples contain tensors with same shape and numbers 
    as well as equal elements in other fields."""
    equality = []
    for i in range(len(observation)):
        if torch.is_tensor(observation[i]):
            #if torch.eq(observation[i], other[i]).sum() == torch.numel(observation[i]):
            if torch.equal(observation[i], other[i]):
                equality.append(True)
        elif observation[i] == other[i]:
            equality.append(True)
    if len(equality) == 5:
        return True
    else:
        return False
    
class HashTensorWrapper():
    """provides rudimental hashing support for tensors
    https://discuss.pytorch.org/t/how-to-put-tensors-in-a-set/123836/6
    """
    def __init__(self, tensor):
        self.tensor = tensor
        self.hashcrap = torch.arange(self.tensor.numel(), device=self.tensor.device).reshape(self.tensor.size())

    def __hash__(self):
        if self.hashcrap.size() != self.tensor.size():
            self.hashcrap = torch.arange(self.tensor.numel(), device=self.tensor.device).reshape(self.tensor.size())
        return hash(torch.sum(self.tensor*self.hashcrap))

    def __eq__(self, other):
        return torch.all(self.tensor == other.tensor)

Let's now chech wheter there are epochs that contain duplicate items (examples).  

In [None]:
allresults = {}

base_dir = '/homes/gws/jacopo/trelium/SeattleFluStudy/debugbatch'
for (root,direcs,files) in os.walk(base_dir):
    dirs = direcs
    break

for directory in dirs: #inspect each epoch (one per directory) 
    totalepochsize = 0
    allepoch = set()
    duplicates = []
    for filename in os.listdir(os.path.join(base_dir, directory)): #inspect each batch
        with open(os.path.join(base_dir,directory, filename), 'rb') as fp:
            data = pickle.load(fp)
            
            batchsize = len(data['inputs_embeds']) #each embedding: torch.Size([5760, 8])
            totalepochsize += batchsize
            
            for i in range(batchsize):
                long = tuple() #transform data in different format
                for key in data:
                    long = long + (data[key][i],)
                
                if long in allepoch: #TODO this does not work, check each field for equlity 
                    duplicates.append(long)
                else:
                    allepoch.add(long)

    allresults[directory] = {'totalepochsize' : totalepochsize, 
                            'duplicates' : duplicates,
                            'files_in_epoch' : allepoch}

In [None]:
#checks whether some epochs contain duplicate examples. 
counts = dict()
for epoch in allresults.keys():
    repetitions = 0
    for observationout in allresults[epoch]['files_in_epoch']: #within an epoch
        for observationin in allresults[epoch]['files_in_epoch']: #within an epoch
            if check_equality(observationin, observationout):
                repetitions +=1
    repetitions -= len(allresults[epoch]['files_in_epoch'])
    counts[epoch] = repetitions
    print(f'done epoch {epoch}')


for i in allresults
    print ('epoch {} contains {} total examples, of which {} are repeated'.format(i,allresults[i]['totalepochsize'],counts[i])

Let's take advantage of the previously defined HashTensorWrapper class 

In [13]:
allresults = {}
device = torch.device('cpu') #where to deserialize tensors 


base_dir = '/homes/gws/jacopo/trelium/SeattleFluStudy/debugbatch_fulldata'
for (root,direcs,files) in os.walk(base_dir):
    dirs = direcs
    break

for directory in dirs: #inspect each epoch (one per directory) 
    totalepochsize = 0
    allepoch = set()
    duplicates = tuple()
    for filename in tqdm(os.listdir(os.path.join(base_dir, directory))): #inspect each batch
        with open(os.path.join(base_dir,directory, filename), 'rb') as fp:
            #data = pickle.load(fp)
            data = torch.load(fp, map_location = device, pickle_module = pickle)
            
            batchsize = len(data['inputs_embeds']) #each embedding: torch.Size([5760, 8])
            totalepochsize += batchsize
            
            for i in range(batchsize):
                long = tuple() #transform data in different format
                for key in data:
                    if torch.is_tensor(data[key][i]):
                        #data[key][i].to(device)  
                        newdatapoint = HashTensorWrapper(data[key][i])
                        long = long + (newdatapoint,)
                    else:
                        long = long + (data[key][i],)
                
                if long in allepoch: 
                    duplicates = duplicates + (long,)
                else:
                    allepoch.add(long)

    allresults[directory] = {'totalepochsize' : totalepochsize, 
                            'duplicates' : duplicates,
                            'files_in_epoch' : allepoch}
    

  0%|                 | 0/217 [00:00<?, ?it/s]


RuntimeError: Invalid magic number; corrupt file?

In [6]:

import os

from src.utils import get_unused_gpus
print('done')

free_devices = get_unused_gpus()
print(free_devices)
print(os.environ.get("CUDA_VISIBLE_DEVICES"))



devices = free_devices[0]
#os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(devices)

print(os.environ.get("CUDA_VISIBLE_DEVICES"))



done
['0', '1', '2', '3']
None
0
