In [1]:
import pickle

In [2]:
#Import necessary libraries

import numpy as np
import pandas as pd
import seal
from seal import *
import sourmash as smsh
import time
import matplotlib.pyplot as plt

In [3]:
#Load 8000 labelled samples, comprising training and test data
def load_data(path="Challenge.fa"):
    """Load the labelled sequences from a two-line-per-record FASTA file.

    The file is read as alternating header and sequence lines: header on
    even lines, sequence on odd lines (counting from 0).  Each header must
    start with one of the four lineage prefixes in ``types``; the text
    between its first and second underscore is kept as the sample number.

    Parameters
    ----------
    path : str, optional
        Location of the FASTA file.  Defaults to ``"Challenge.fa"``.

    Returns
    -------
    list of list
        One ``[class_index, virus_number, sequence]`` entry per record.
        ``class_index`` is the position of the matching prefix in ``types``,
        ``virus_number`` is a string, and ``sequence`` is the sequence line
        with surrounding whitespace stripped.

    Raises
    ------
    ValueError
        If the file ends with a header that has no sequence line, or if a
        header does not start with any of the known lineage prefixes.
    IndexError
        If a matching header contains no underscore.
    """
    with open(path, "r") as f:
        data = f.readlines()

    # Blank lines at the very end of the file are not records; drop them so a
    # trailing newline is not mistaken for an extra header.
    while data and not data[-1].strip():
        data.pop()

    if len(data) % 2 != 0:
        raise ValueError(f"{path}: odd number of lines, a header has no sequence line")

    types = [">B.1.526", ">B.1.1.7", ">B.1.427", ">P.1"]

    dataframe = []

    for header, raw_sequence in zip(data[0::2], data[1::2]):
        # 2021/08/02: plain loop instead of match-case (Python 3.10) for backwards compatibility
        for class_index, prefix in enumerate(types):
            if header.startswith(prefix):
                virus_number = header.split("_")[1].strip()
                dataframe.append([class_index, virus_number, raw_sequence.strip()])
                break
        else:
            # Raising a bare string is itself a TypeError in Python 3, which hid
            # the real problem; raise a proper exception naming the header.
            raise ValueError(f"Bad entry: {header.strip()!r}")

    return dataframe

In [4]:
# Parse Challenge.fa into a list of [class_index, virus_number, sequence] entries.
data = load_data()

In [5]:
# Tabulate the loaded records; column 0 holds the integer class index of each sample.
data_df = pd.DataFrame(data)
# Independent copy of the class column, so entries can later be removed from it
# without touching data_df.
labels = data_df[0].copy()

In [6]:
# Lookup table from an integer in 0..3 to a nucleotide letter, in the order A, C, G, T.
base_dict = dict(enumerate('ACGT'))

In [7]:
#Key preprocessing step:
#Replace all non-ACTG characters with an ACTG chosen uniformly at random.
#The draws come from a dedicated, seeded generator so that re-running the notebook
#reproduces the same substituted sequences.
start = time.time()
unknown_base_rng = np.random.default_rng(101)
data_Nrand = []

for record in data:
    #Collect the characters in a list and join once, instead of growing a string one character at a time.
    bases = []
    for base in record[2]:
        if base in ('A', 'C', 'G', 'T'):
            bases.append(base)
        else:
            #integers(0, 4) draws from {0, 1, 2, 3}; the upper bound is exclusive.
            bases.append(base_dict[int(unknown_base_rng.integers(0, 4))])
    data_Nrand.append([record[0], record[1], ''.join(bases)])

end = time.time()
print(f'Time to Replace unknowns: {(end-start):.3f}s')

Time to Replace unknowns: 126.632s


In [8]:
#Sketch parameters settled on for this model: N is passed to MinHash as n, K as ksize.
#One sketch is built per preprocessed sample, in the same order as data_Nrand.
start = time.time()
N = 5000
K = 33
sketches = []

for record in data_Nrand:
    sketch = smsh.MinHash(n=N,ksize=K)
    #record[2] is the sequence with unknown bases already replaced.
    sketch.add_sequence(record[2])
    sketches.append(sketch)

end = time.time()
print(f'Time to form sketches: {(end-start):.3f}s')

Time to form sketches: 37.683s


In [8]:
#Set aside 1000 samples as test set held by data owner.
#Positions are drawn from however many samples were actually loaded, not a hard-coded 8000,
#so the sampled positions always refer to existing rows.
s = pd.Series(np.arange(len(data_df)))
test_samples = s.sample(n=1000, random_state = 101)
test_samples

7289    7289
5447    5447
5159    5159
3020    3020
1866    1866
        ... 
4280    4280
2146    2146
6111    6111
1724    1724
6194    6194
Length: 1000, dtype: int64

In [10]:
# Wrap the list of sketches in a single-column DataFrame (one MinHash object per row),
# then pick out the rows at the sampled test positions.
sketches_df = pd.DataFrame(sketches)
test_positions = list(test_samples)
test_sketches = sketches_df.iloc[test_positions]

In [9]:
# Class labels of the held-out samples, looked up by index label.
test_labels = labels.loc[list(test_samples)]
test_labels

7289    0
5447    3
5159    3
3020    1
1866    2
       ..
4280    3
2146    1
6111    0
1724    2
6194    0
Name: 0, Length: 1000, dtype: int64

In [11]:
#Save test labels for Data Owner to access in Step 4.
#In real situation these would not be accessible to Model Owner.
#Written in pandas' pickle format to the file 'data_owner_labels' in the working directory.

test_labels.to_pickle('data_owner_labels')

In [11]:
# Test positions from largest to smallest, ready for removal from the back of the list forwards.
test_indices = sorted(test_samples, reverse=True)

In [12]:
#Hold test sketches aside from training set.
#test_indices is in descending order, so each removal leaves the positions still to be removed unchanged.
#The list is modified in place: after this cell `sketches` holds only the training sketches.
for test_position in test_indices:
    del sketches[test_position]

In [13]:
#Remove test labels
#Series.pop removes by index label and modifies `labels` in place; the returned value is not needed.
for test_label_key in test_indices:
    labels.pop(test_label_key)

In [14]:
#MODEL OWNER (TRAINING)
#Compute full matrix of Jaccard similarities. Takes a long time.
#Only the strict upper triangle (i < j) is filled; the diagonal and lower triangle stay 0.
#The matrix is sized from the training sketches rather than a hard-coded 7000.
start = time.time()
n_train = len(sketches)
jacc_sim = np.zeros((n_train, n_train))

for i in range(n_train):
    sketch_i = sketches[i]
    for j in range(i+1, n_train):
        #Similarities are stored rounded to 4 decimal places.
        jacc_sim[i,j] = round(sketch_i.jaccard(sketches[j]),4)

end = time.time()
print(f'Time to compute similarities between all training sketches: {(end-start):.3f}s')

Time to compute similarities between all training sketches: 4610.299s


In [16]:
#Turn Jaccard similarities into matrix of distances.
#For each pair i < j the distance is -log(2J) + log(1+J), i.e. -log(2J / (1 + J)),
#mirrored into the lower triangle; the diagonal stays 0.
#Each row's upper-triangle part is converted in one NumPy operation instead of element by element,
#and the size is taken from jacc_sim rather than a hard-coded 7000.
start = time.time()
n_train = jacc_sim.shape[0]
dist_adj = np.zeros((n_train, n_train))

for i in range(n_train - 1):
    row_sim = jacc_sim[i, i+1:]
    #A similarity of 0 gives log(0) = -inf and hence an infinite distance, as the element-wise formula does.
    row_dist = -np.log(2*row_sim)+np.log(1+row_sim)
    dist_adj[i, i+1:] = row_dist
    dist_adj[i+1:, i] = row_dist

end = time.time()
print(f'Time to compute training distances: {(end-start):.3f}s')

Time to compute training distances: 102.545s


In [17]:
# Wrap the distance matrix in a DataFrame so that columns can be selected by integer label.
dist_adj_df = pd.DataFrame(dist_adj)

In [18]:
from sklearn.model_selection import train_test_split

In [19]:
#Model is based on distances to 12 randomly chosen "anchor" samples
#Anchor positions are drawn from the rows of the training distance table, not a hard-coded 7000,
#so every anchor is guaranteed to be a valid column of dist_adj_df.
s = pd.Series(np.arange(len(dist_adj_df)))
anchors = s.sample(n=12,random_state=5)

In [20]:
# Plain list of the sampled anchor positions (the values of the `anchors` Series).
anchor_indices = list(anchors)

In [21]:
#Split Model Owner's samples into his own training and test set for validation
#Features: the columns of the distance table selected by the anchor positions (one column per anchor).
anchor_distances = dist_adj_df[anchor_indices]
#Targets: the remaining class labels flattened to a 1-D array.
class_targets = np.ravel(labels)
X_train, X_test, y_train, y_test = train_test_split(
    anchor_distances, class_targets, test_size=0.15, random_state=5
)

In [22]:
from sklearn.linear_model import LogisticRegression

In [23]:
#MODEL OWNER (TRAINING)
#Fit a logistic regression model based on distances to anchors.
#fit_intercept=False: no intercept term is fitted; all other settings are scikit-learn defaults.
logmodel = LogisticRegression(fit_intercept=False)
logmodel.fit(X_train,y_train)

LogisticRegression(fit_intercept=False)

In [24]:
# Predicted class for each sample in the Model Owner's own validation split.
predictions = logmodel.predict(X_test)

In [25]:
#The 4*12 matrix of coefficients of the model: one row per class, one column per anchor.
#This is the IP the Model Owner wishes to protect.
logmodel.coef_

array([[  5.34780467, -15.0303433 ,   2.7803143 ,   2.02359664,
        -14.47587293,   3.09000255,   5.26324089,   7.87415749,
          3.2890678 ,   3.5120713 ,   1.98378768, -15.17930856],
       [  2.55429896,   3.36550033,   2.51032959,   2.74504425,
          2.78878523,   1.22940156,   3.92857415, -26.31853051,
          1.24554117,   2.67423688,   2.00455739,   3.26833358],
       [  9.04437978,   7.16669411,  -9.40140473,  -8.66478166,
          6.56027444,  -8.09806074,   8.2587428 ,  10.8551519 ,
         -8.94915775,  -8.6965517 ,  -8.37131341,   7.1085428 ],
       [-16.94648342,   4.49814887,   4.11076085,   3.89614077,
          5.12681326,   3.77865663, -17.45055784,   7.58922112,
          4.41454878,   2.51024351,   4.38296834,   4.80243218]])

In [26]:
from sklearn.metrics import classification_report,confusion_matrix

In [27]:
#Validate the model on a test set (not the Data Owner's test set)
#Rows are the true classes, columns the predicted classes (scikit-learn's convention).
print(confusion_matrix(y_test,predictions))

[[263   0   0   1]
 [  0 265   0   0]
 [  0   0 252   0]
 [  0   0   0 269]]


In [28]:
#Save model data for Model Owner's use in Step 3
#Each array is written with ndarray.dump to a file in the working directory.

logmodel.classes_.dump('logmodel_classes.dump')
logmodel.intercept_.dump('logmodel_intercept.dump')
logmodel.coef_.dump('logmodel_coef.dump')


#Save test sketches for Data Owner in Step 2
#In real situation, Data Owner would hold these from the start
#The context manager closes (and flushes) the file even if pickling fails;
#previously the handle returned by open() was never closed.

with open('test_sketches.dump','wb') as test_sketches_file:
    pickle.dump(test_sketches, test_sketches_file)

#Data below isn't used again.

#pickle.dump(sketches, open('sketches.dump','wb'))
#pickle.dump(anchor_indices, open('anchor_indices.dump', 'wb'))
#data_df.to_pickle('data_df.dump')
#test_samples.to_pickle('test_samples.dump')

In [35]:
#anchor_indices are positions in the training set (the rows/columns of the distance matrix),
#which was built from `sketches` AFTER the test samples were removed from that list.
#sketches_df was created before the removal and still contains every sample, so the same
#positions there refer to different samples. Look the anchors up in the training sketches instead.
anchor_sketches = pd.DataFrame(sketches).iloc[anchor_indices]
anchor_sketches

Unnamed: 0,0
4187,<sourmash.minhash.MinHash object at 0x7f864431...
5538,<sourmash.minhash.MinHash object at 0x7f86442d...
141,<sourmash.minhash.MinHash object at 0x7f864458...
9,<sourmash.minhash.MinHash object at 0x7f86704b...
5350,<sourmash.minhash.MinHash object at 0x7f86442d...
27,<sourmash.minhash.MinHash object at 0x7f86445b...
4438,<sourmash.minhash.MinHash object at 0x7f864432...
3291,<sourmash.minhash.MinHash object at 0x7f864436...
735,<sourmash.minhash.MinHash object at 0x7f86445a...
960,<sourmash.minhash.MinHash object at 0x7f86445b...


In [36]:
# Display the held-out test sketches (one MinHash object per row, indexed by sample position).
test_sketches

Unnamed: 0,0
7289,<sourmash.minhash.MinHash object at 0x7f864423...
5447,<sourmash.minhash.MinHash object at 0x7f86442d...
5159,<sourmash.minhash.MinHash object at 0x7f86442c...
3020,<sourmash.minhash.MinHash object at 0x7f864435...
1866,<sourmash.minhash.MinHash object at 0x7f86443a...
...,...
4280,<sourmash.minhash.MinHash object at 0x7f864431...
2146,<sourmash.minhash.MinHash object at 0x7f86443a...
6111,<sourmash.minhash.MinHash object at 0x7f864427...
1724,<sourmash.minhash.MinHash object at 0x7f864439...


In [34]:
#Save anchor sketches to send to Data Owner in Step 2
#The context manager closes (and flushes) the file even if pickling fails;
#previously the handle returned by open() was never closed.

with open('anchor_sketches.dump','wb') as anchor_sketches_file:
    pickle.dump(anchor_sketches, anchor_sketches_file)