In [None]:
# Scientific / plotting / utility imports.
# NOTE(review): coo_matrix, pandas, cv2, nn, F, optim, SubsetRandomSampler,
# MyDataset and random are not used in the cells shown here — confirm they
# are needed before removing them.
import scipy.io
from scipy.sparse import coo_matrix
import pandas as pd
import matplotlib.pyplot as plt
import time
import cv2
import numpy as np

from sys import getsizeof

# PyTorch
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data.sampler import SubsetRandomSampler

# Project-local dataset class.
from MyDataset import MyDataset
import random

# Record the torch version in the notebook output for reproducibility.
print(torch.__version__)

In [None]:
# From here on, new floating-point torch tensors default to float64.
torch.set_default_dtype(torch.float64)

# Read the MATLAB export and keep only its 'data' struct; squeeze_me=True
# collapses MATLAB's singleton dimensions.
peakdata = scipy.io.loadmat('peakdata.mat', squeeze_me=True)['data']

In [None]:
# Per-axis transforms applied to the raw MATLAB fields.
# The m/z factor is deliberately NOT named `mzscale`: the window-grid cell
# binds `mzscale` to the m/z -> pixel factor used during rasterisation, and
# re-running this cell used to silently reset that factor to 1.
mz_input_scale = 1
rtscale = 1
intscale = 1/3   # cube-root compression of intensities

mzs = peakdata['mz'].item()*mz_input_scale
rts = peakdata['rt'].item()*rtscale
ints = peakdata['int'].item()**(intscale)
indices = peakdata['peak_index'].item()

# Show each raw (unscaled) field under a label, separated by blank lines.
for _label, _field in (("mz", "mz"),
                       ("rt", "rt"),
                       ("int", "int"),
                       ("peak indices", "peak_index")):
    print(_label)
    print(peakdata[_field].item())
    print("\n")


In [None]:
# Summary counts. A peak_index of 0 marks a point that belongs to no peak,
# so the 0 label must not be counted as one of the distinct peaks.
print("Number of 'scans':", len(indices))
print("Number of unique peaks:", len(np.unique(indices[indices != 0])))
print("Number of scans containing peaks:", np.count_nonzero(indices))

In [None]:
# Window geometry: the (rt, m/z) plane is cut into overlapping tiles whose
# starts are dist_* apart and whose widths are dist_* + *_overlap.
dist_mz = 1
dist_rt = 0.7
mzH = int(np.ceil(np.max(mzs)))
print('mzH: ' + str(mzH))
mzL = int(np.floor(np.min(mzs)))
print('mzL: ' + str(mzL))
rtH = np.ceil(np.max(rts)*100)/100
print('rtH: ' + str(rtH))
rtL = np.floor(np.min(rts)*100)/100
print('rtL: ' + str(rtL))
mz_overlap = 0.10
rt_overlap = 0.10
# m/z -> pixel-row factor used when rasterising windows (200 rows per m/z unit).
mzscale = 2*1e2
# Window starts spaced exactly dist_rt apart. The previous
# np.linspace(rtL, rtH, int((rtH - rtL)/dist_rt)) spaced them
# (rtH - rtL)/(n - 1) apart, i.e. wider than dist_rt, which shrank the
# configured overlap and could leave gaps between windows on short runs.
# A final start at (or just past) rtH is kept on purpose because the
# extraction loop iterates over rt_range[:-1].
rt_range = np.arange(rtL, rtH + dist_rt, dist_rt).tolist()
print(rt_range)
print(len(rt_range))
mz_range = range(mzL, mzH, dist_mz)
print(mz_range)
print(len(mz_range))

In [None]:
# Binarise the peak labels into a 0/1 float32 mask (0 = no peak, 1 = any peak).
# A new array is built instead of writing into `indices` in place: `indices`
# is the array object returned by peakdata['peak_index'].item(), so the
# in-place write also altered `peakdata` and changed what the summary-count
# cell reports when it is re-run. This version is safe to re-run.

print(type(indices))
print(indices.dtype)

indices = (indices != 0).astype('float32')
print(type(indices))
print(indices.dtype)


In [None]:
# Slide an overlapping (m/z, rt) window over the data and rasterise each
# non-empty window into a 1 x IMG_SIZE x IMG_SIZE sparse intensity image and
# a matching 0/1 peak mask. Windows containing at least one labelled peak go
# to the *True lists (ii_list[k] holds the data-point indices behind
# imagesTrue[k]); all others go to the *False lists.
IMG_SIZE = 256

ii_list = []

imagesTrue = []
masksTrue = []
labelsTrue = []

imagesFalse = []
masksFalse = []
labelsFalse = []

dataCount = 0    # windows rasterised
peakCount = 0    # of which contain at least one labelled peak
emptyCount = 0   # windows without any data point (skipped)

for mz in mz_range:
    # Lower bounds are inclusive: with a strict `>` a point lying exactly on
    # mzL / rtL (the floored data minimum) fell into no window at all.
    in_mz = (mzs >= mz) & (mzs < (mz + dist_mz + mz_overlap))
    t = time.time()
    for rt in rt_range[:-1]:
        in_rt = (rts >= rt) & (rts < (rt + dist_rt + rt_overlap))
        ii = np.flatnonzero(in_mz & in_rt).tolist()
        if not ii:
            # np.min on an empty selection raises, so skip empty windows.
            emptyCount += 1
            continue

        # rt -> column: offset from the window's earliest rt divided by the
        # mean spacing of its distinct rts. A window with a single distinct
        # rt has no spacing (the mean of an empty diff is NaN), so every
        # point lands in column 0.
        window_rts = np.unique(rts[ii])
        rt_step = np.mean(np.diff(window_rts)) if window_rts.size > 1 else 1.0
        rts_int = np.rint((rts[ii] - window_rts[0]) / rt_step).astype(int)
        # m/z -> row: `mzscale` rows per m/z unit, offset from the window minimum.
        mzs_int = (np.rint(mzs[ii]*mzscale) - np.rint(np.min(mzs[ii])*mzscale)).astype(int)

        # Check explicitly that the window fits the fixed grid, so an
        # oversize window fails with a message that names it.
        if mzs_int.max() >= IMG_SIZE or rts_int.max() >= IMG_SIZE:
            raise ValueError(
                f"window (mz={mz}, rt={rt}) needs {mzs_int.max() + 1} x "
                f"{rts_int.max() + 1} pixels, more than {IMG_SIZE} x {IMG_SIZE}")

        # NOTE: points that round to the same pixel become separate sparse
        # entries; coalesce()/to_dense() add duplicates, so a mask pixel can
        # exceed 1.
        pixel_idx = np.vstack([mzs_int, rts_int])
        mat = torch.sparse_coo_tensor(indices=pixel_idx, values=ints[ii],
                                      size=[IMG_SIZE, IMG_SIZE]).unsqueeze(0)
        mask = torch.sparse_coo_tensor(indices=pixel_idx, values=indices[ii],
                                       size=[IMG_SIZE, IMG_SIZE]).unsqueeze(0)

        dataCount += 1
        if np.any(indices[ii] > 0):
            peakCount += 1
            ii_list.append(ii)
            imagesTrue.append(mat)
            masksTrue.append(mask)
            labelsTrue.append(True)
        else:
            imagesFalse.append(mat)
            masksFalse.append(mask)
            labelsFalse.append(False)

    # One progress line per m/z strip instead of many prints per window.
    print(f"mz {mz}: {dataCount} windows, "
          f"{peakCount / max(dataCount, 1) * 100:.1f}% with labelled peak(s), "
          f"{time.time() - t:.2f} s")

print(f"done: {dataCount} windows rasterised, {peakCount} with peaks, "
      f"{emptyCount} empty windows skipped")

In [None]:
# Shuffle the peak-containing windows with ONE seeded permutation so that
# ii_list[k], imagesTrue[k], masksTrue[k] and labelsTrue[k] keep describing
# the same window, then expose them as the dataset to save.
#
# Previously each list was shuffled separately after re-seeding np.random,
# which stays aligned only while every list has the same length (ii_list was
# not truncated alongside the others). Also, `images`/`masks`/`labels` were
# defined only in commented-out code, so the save cell raised NameError on a
# fresh kernel.


def _shuffle_in_unison(seed, *seqs):
    """Return list copies of equal-length ``seqs`` reordered by one permutation.

    The permutation comes from the legacy ``RandomState(seed)`` generator,
    which should reproduce the order that
    ``np.random.seed(seed); np.random.shuffle(seq)`` gave each list.

    Raises:
        ValueError: if the sequences do not all have the same length.
    """
    lengths = {len(s) for s in seqs}
    if len(lengths) > 1:
        raise ValueError(f"sequences differ in length: {sorted(lengths)}")
    n = lengths.pop() if lengths else 0
    perm = np.random.RandomState(seed).permutation(n)
    return tuple([s[i] for i in perm] for s in seqs)


# NOTE(review): max() makes the [:size] truncation a no-op; if a
# class-balanced dataset (True + False windows) is intended, min() is
# presumably what was meant.
size = max(len(imagesTrue), len(imagesFalse))
print(size)

ii_list, imagesTrue, masksTrue, labelsTrue = (
    seq[:size]
    for seq in _shuffle_in_unison(0, ii_list, imagesTrue, masksTrue, labelsTrue)
)

# Only peak-containing windows are saved; the negative windows remain in
# imagesFalse / masksFalse / labelsFalse.
images, masks, labels = imagesTrue, masksTrue, labelsTrue
print(len(ii_list), len(images), len(masks), len(labels))


In [None]:
# Persist images, masks and labels, each to its own file.
for _obj, _path in (
    (images, '256x256_images_100_percent.pt'),
    (masks, '256x256_masks_100_percent.pt'),
    (labels, '256x256_labels_100_percent.pt'),
):
    torch.save(_obj, _path)
del _obj, _path

In [None]:
# Reload images, masks and labels from the saved .pt files, in that order.
images, masks, labels = (
    torch.load(_path)
    for _path in (
        '256x256_images_100_percent.pt',
        '256x256_masks_100_percent.pt',
        '256x256_labels_100_percent.pt',
    )
)

In [None]:
# Inspect one peak-containing window from its raw points (x = rt, y = m/z):
# first sized by intensity, then showing only the labelled peak points.
image_nr = 178
if image_nr >= len(ii_list):
    raise IndexError(
        f"image_nr={image_nr}, but only {len(ii_list)} peak windows were extracted")

ii = ii_list[image_nr]

# 1 for points labelled as part of a peak, 0 otherwise (marker size 0),
# so the second plot shows only the labelled points.
int_ind = (indices[ii] != 0).astype(int).tolist()


def _scatter_points(title, rt, mz, sizes):
    """Scatter points with rt on x and m/z on y, one marker size per point."""
    fig, ax = plt.subplots(dpi=200)
    ax.scatter(rt, mz, s=sizes)
    ax.set(title=title, xlabel='rt', ylabel='m/z')
    return ax


_scatter_points('Raw data', rts[ii], mzs[ii], 0.05*ints[ii])
_scatter_points('Annotated peaks', rts[ii], mzs[ii], int_ind);


In [None]:
# Inspect the same window after rasterisation: densify the sparse tensors and
# draw one marker per pixel (x = rt bin on the last axis, y = m/z bin).
myTarget = masksTrue[image_nr].to_dense()
myIntensities = imagesTrue[image_nr].to_dense()

# labelsTrue is the label list kept aligned with imagesTrue / masksTrue.
print(labelsTrue[image_nr])

# Grid size taken from the tensor itself rather than hard-coded.
n_mz_bins, n_rt_bins = myIntensities.shape[-2:]

myTarget = myTarget.flatten()
print(myTarget.size())

myIntensities = myIntensities.flatten()
print(myIntensities.size())

# 1-based pixel coordinates in the same row-major order as flatten():
# x (rt bin) varies fastest, y (m/z bin) advances once per row.
x_grid, y_grid = np.meshgrid(np.arange(1, n_rt_bins + 1, dtype=float),
                             np.arange(1, n_mz_bins + 1, dtype=float))
x_axis = x_grid.ravel()
y_axis = y_grid.ravel()

fig, ax = plt.subplots(dpi=200)
ax.scatter(x_axis, y_axis, s=(0.05*myIntensities.numpy())**2)
ax.set(title='Intensities', xlabel='rt bin', ylabel='m/z bin')

fig, ax = plt.subplots(dpi=200)
ax.scatter(x_axis, y_axis, s=myTarget.numpy())
ax.set(title='Labels', xlabel='rt bin', ylabel='m/z bin');