# Data Generation Modeling of ZombieVenom

In [1]:
import numpy as np
import pandas as pd
import csv

from sklearn.utils import shuffle
import torch
from torch.distributions.one_hot_categorical import OneHotCategorical 
import pyro
import pyro.distributions as dist

In [2]:
def disease_model(prob_vec):
    """Sample one binary symptom vector from per-symptom probabilities.

    Each entry of ``prob_vec`` is used as the parameter of a Bernoulli
    distribution registered as a Pyro sample site named after the symptom.

    Args:
        prob_vec: Indexable collection of exactly eight probabilities, in the
            order pain, bleeding, swelling_limb, fallen_eye, hard_swallowing,
            short_breath, necrosis, arrhythmia.

    Returns:
        A list of eight ints (0 or 1) in the same order as ``prob_vec``, or
        ``None`` (after printing a message) when ``prob_vec`` does not hold
        exactly eight entries.
    """
    # Sample-site names; the position matches the index into prob_vec.
    symptom_sites = (
        'pain',
        'bleeding',
        'swelling_limb',
        'fallen_eye',
        'hard_swallowing',
        'short_breath',
        'necrosis',
        'arrhythmia',
    )

    if len(prob_vec) != 8:
        print ("Number of disease parameter are eight")
        return None

    symptoms = []
    for position, site in enumerate(symptom_sites):
        draw = pyro.sample(site, dist.Bernoulli(prob_vec[position]))
        # Bernoulli yields a float tensor; store a plain int flag instead.
        symptoms.append(1 if draw.item() == 1.0 else 0)

    return symptoms

In [3]:
def generate_data(n_cases = 20):
    """Generate synthetic cases, each a symptom vector plus its diagnosis.

    For every case a diagnosis is drawn with equal probability from the four
    candidates below, and ``disease_model`` is called with that diagnosis'
    symptom probabilities.

    Args:
        n_cases: Number of cases to generate.

    Returns:
        A list of ``n_cases`` rows; each row is the list returned by
        ``disease_model`` with the diagnosis label appended as last element.
    """
    # (diagnosis label, symptom probabilities) kept together so a label can
    # never be paired with another diagnosis' probability vector.
    profiles = [
        ("jararaca", [1.0, 0.8, 1.0, 0.0, 0.0, 0.0, 0.5, 0.0]),
        ("cascavel", [0.5, 0.0, 0.5, 1.0, 0.5, 0.5, 0.0, 0.0]),
        ("aranha marrom", [0.8, 0.0, 0.8, 0.0, 0.0, 0.0, 0.8, 0.0]),
        ("escorpião", [0.5, 0.0, 1.0, 0.0, 0.0, 0.5, 0.0, 1.0]),
    ]

    # Probability of each diagnosis occurring. The distribution does not
    # depend on the case, so it is built once instead of once per iteration.
    disease_dist = OneHotCategorical(torch.tensor([0.25, 0.25, 0.25, 0.25]))

    cases = []
    for _ in range(n_cases):
        # The sample is a one-hot vector; its argmax is the chosen index.
        # torch.argmax avoids an implicit tensor -> NumPy conversion and
        # int() gives a plain Python index.
        d_idx = int(torch.argmax(disease_dist.sample()))
        diagnosis, symptom_probs = profiles[d_idx]
        cases.append(disease_model(symptom_probs) + [diagnosis])

    return cases

In [4]:
def save_to_csv(mylist, filename):
    """Write cases to a comma-separated file, preceded by a header row.

    The file is always written as UTF-8: the header (and the diagnosis
    labels produced elsewhere in this file) contain non-ASCII characters, so
    relying on the platform default encoding would make the output
    locale-dependent or fail outright.

    Args:
        mylist: Iterable of rows; each row is an iterable of nine values
            matching the header columns.
        filename: Path of the file to create or overwrite.
    """
    header = ["Dor", "Sangramento", "Membro inchado", "Olho caído", "Dificuldade deglutição",
              "Insuficiência respiratória", "Necrose","Arritmia cardíaca", "Diagnóstico"]

    # newline='' lets the csv module control line endings itself.
    with open(filename, 'w', newline='', encoding='utf-8') as myfile:
        wr = csv.writer(myfile, delimiter=',')
        wr.writerow(header)
        wr.writerows(mylist)

In [30]:
# Generate a pool of 500 cases and shuffle it, then walk the first 100 in
# batches of ten. A batch is printed unconditionally, but written to disk
# only when every one of the four diagnosis labels appears in it.
cases = shuffle(generate_data(n_cases=500))
required_labels = ('jararaca', 'cascavel', 'aranha marrom', 'escorpião')
for i in range(0, 100, 10):
    cases10 = np.array(cases[i:i + 10])
    print(cases10)
    if all(label in cases10 for label in required_labels):
        # Files are numbered by batch position: ZombieVenom0.csv, ZombieVenom1.csv, ...
        save_to_csv(cases10, "ZombieVenom" + str(i // 10) + ".csv")

[['0' '0' '1' '0' '0' '0' '0' '1' 'escorpião']
 ['1' '0' '1' '0' '0' '0' '1' '0' 'aranha marrom']
 ['1' '1' '1' '0' '0' '0' '1' '0' 'jararaca']
 ['0' '0' '1' '0' '0' '0' '0' '1' 'escorpião']
 ['0' '0' '1' '0' '0' '1' '0' '1' 'escorpião']
 ['1' '0' '1' '0' '0' '0' '1' '0' 'jararaca']
 ['0' '0' '0' '1' '0' '0' '0' '0' 'cascavel']
 ['1' '0' '1' '0' '0' '0' '0' '1' 'escorpião']
 ['1' '1' '1' '0' '0' '0' '0' '0' 'jararaca']
 ['1' '0' '1' '0' '0' '0' '0' '1' 'escorpião']]
[['0' '0' '1' '0' '0' '0' '1' '0' 'aranha marrom']
 ['1' '1' '1' '0' '0' '0' '1' '0' 'jararaca']
 ['1' '1' '1' '0' '0' '0' '1' '0' 'jararaca']
 ['1' '0' '1' '0' '0' '0' '0' '1' 'escorpião']
 ['1' '0' '1' '1' '0' '0' '0' '0' 'cascavel']
 ['0' '0' '0' '1' '0' '0' '0' '0' 'cascavel']
 ['0' '0' '1' '0' '0' '0' '0' '1' 'escorpião']
 ['1' '0' '0' '1' '0' '1' '0' '0' 'cascavel']
 ['1' '1' '1' '0' '0' '0' '0' '0' 'jararaca']
 ['1' '0' '1' '0' '0' '0' '1' '0' 'aranha marrom']]
[['1' '0' '1' '0' '0' '0' '0' '1' 'escorpião']
 ['1' '0'