In [None]:
import pandas as pd
import numpy as np
import pickle
import argparse
import os

import torch

In [None]:
# Select the compute device: a CUDA GPU when one is available, otherwise the CPU.
# The notebook was written for a multi-GPU host and asked for GPU index 4;
# requesting an index that does not exist raises as soon as the device is
# queried or used, so fall back to GPU 0 on hosts with fewer GPUs.
PREFERRED_GPU_INDEX = 4

device = torch.device('cpu')

if torch.cuda.is_available():
    gpu_index = PREFERRED_GPU_INDEX if PREFERRED_GPU_INDEX < torch.cuda.device_count() else 0
    device = torch.device(f'cuda:{gpu_index}')
    torch.cuda.empty_cache()
    print("Device set to : " + str(torch.cuda.get_device_name(device)))
else:
    print("Device set to : cpu")

# Arguments

In [None]:
# @title Arguments
# Build the run configuration. The parser mirrors a command-line interface,
# but inside the notebook it is parsed with an empty argument list so every
# option takes its default value.
parser = argparse.ArgumentParser(description='Actor Critic')

# Training
parser.add_argument('--data', default="/mnt/kerem/CEU", type=str, help='Dataset Path')
parser.add_argument('--epochs', default=64, type=int, metavar='N', help='Number of epochs for training agent.')
parser.add_argument('--episodes', default=10000, type=int, metavar='N', help='Number of episodes for training agent.')
parser.add_argument('--lr', '--learning-rate', default=0.005, type=float, metavar='LR', help='initial learning rate', dest='lr')
parser.add_argument('--wd', default=0.0001, type=float, help='Weight decay for training optimizer')
parser.add_argument('--seed', default=3, type=int, help='Seed for reproducibility')
parser.add_argument('--model-name', default="PPO", type=str, help='Model name for saving model.')
parser.add_argument('--gamma', default=0.99, type=float, metavar='N', help='The discount factor as mentioned in the previous section')
parser.add_argument('--val_freq', default=50, type=int, metavar='N', help='Validation frequencies')

# Model
# `type=int` matters: without it a value supplied on the command line would
# reach the model as a string instead of a layer width.
parser.add_argument("--latent1", default=256, type=int, required=False, help="Latent Space Size for first layer of network.")
parser.add_argument("--latent2", default=256, type=int, required=False, help="Latent Space Size for second layer of network.")

# Env Properties
parser.add_argument('--control_size', default=20, type=int, help='Beacon and Attacker Control group size')
parser.add_argument('--gene_size', default=100, type=int, help='States gene size')
parser.add_argument('--beacon_size', default=60, type=int, help='Beacon population size')
parser.add_argument('--victim_prob', default=0.8, type=float, help='Victim inside beacon or not!')
parser.add_argument('--pop_reset_freq', default=10, type=int, help='Reset Population Frequency (Epochs)')
parser.add_argument('--max_queries', default=10, type=int, help='Maximum queries per episode')

# NOTE(review): `--state_dim` has no `type`, so only its tuple default is
# usable; a value passed on the command line would arrive as a plain string.
parser.add_argument("--state_dim", default=(4,), required=False, help="State Dimension")
parser.add_argument("--n-actions", default=1, type=int, required=False, help="Actions Count for each state")

# utils
parser.add_argument('--resume', default="", type=str, metavar='PATH', help='path to latest checkpoint (default: none)')
parser.add_argument('--save-dir', default='./results', type=str, metavar='PATH', help='path to cache (default: none)')

# args = parser.parse_args()  # running in command line
args = parser.parse_args([])  # running in ipynb: ignore the kernel's own argv

# set command line arguments here when running in ipynb
if not args.save_dir:
    args.save_dir = "./"

args.results_dir = args.save_dir

# exist_ok avoids the check-then-create race and makes the cell safe to re-run.
os.makedirs(args.results_dir, exist_ok=True)

args.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

print(args)

# Read Data

In [4]:
# CEU Beacon - it contains 164 people in total which we will divide into groups to experiment
# (`sep=r"\s+"` is the supported spelling of the deprecated `delim_whitespace=True`.)
beacon = pd.read_csv(os.path.join(args.data, "Beacon_164.txt"), index_col=0, sep=r"\s+")
# Reference genome, i.e. the genome that has no SNPs, all major allele pairs for each position
# NOTE: unpickling can execute arbitrary code -- only load a reference file you trust.
# The context manager closes the file handle, which the bare `open(...)` call never did.
with open(os.path.join(args.data, "reference.pickle"), "rb") as reference_file:
    reference = pickle.load(reference_file)
# Binary representation of the beacon; 0: no SNP (i.e. no mutation) 1: SNP (i.e. mutation)
# A cell counts as a SNP when it differs from the reference and is not the "NN" placeholder.
binary = np.logical_and(beacon.values != reference, beacon.values != "NN").astype(int)

In [6]:
# Table that contains MAF (minor allele frequency) values for each position.
# Built as one chain with no in-place mutation, so re-running the cell always
# starts from the file on disk. (`sep=r"\s+"` replaces the deprecated
# `delim_whitespace=True`.)
maf = (
    pd.read_csv(os.path.join(args.data, "MAF.txt"), index_col=0, sep=r"\s+")
    .rename(columns={'referenceAllele': 'major', 'referenceAlleleFrequency': 'major_freq',
                     'otherAllele': 'minor', 'otherAlleleFrequency': 'minor_freq'})
    # Round the frequencies to three decimals, overwriting the existing column.
    .assign(maf=lambda frame: np.round(frame["maf"].values, 3))
)
# Same variable with sorted maf values
sorted_maf = maf.sort_values(by='maf')
# Extracting column to an array for future use
maf_values = maf["maf"].values

In [7]:
# Sanity check: show the shape of each loaded object side by side.
tuple(loaded.shape for loaded in (beacon, reference, binary, maf_values))

((4029840, 164), (4029840, 1), (4029840, 164), (4029840,))

# PPO

In [8]:
# ---------------- Action space ----------------
has_continuous_action_space = True

# Std of the action distribution (Multivariate Normal): start value, floor,
# and the amount subtracted on each linear decay step.
action_std = 0.4
min_action_std = 0.1                # decay stops once action_std <= min_action_std
action_std_decay_rate = 0.05        # action_std = action_std - action_std_decay_rate
action_std_decay_freq = 250_000     # how often the decay is applied (units not shown in this notebook)

# ---------------- PPO hyperparameters ----------------
K_epochs = 64                       # update policy for K epochs
eps_clip = 0.2                      # clip parameter for PPO
gamma = 0.99                        # discount factor

# ---------------- Optimizer ----------------
lr_actor = 3e-4                     # learning rate for actor network
lr_critic = 1e-3                    # learning rate for critic network

# ---------------- Reproducibility ----------------
random_seed = 0                     # set random seed if required (0 = no random seed)

In [9]:
%load_ext autoreload
%autoreload 2
from environment import BeaconEnv
from ppo import PPO
from engine import train

def main():
    """Build the Beacon environment and a PPO agent, then start training.

    Reads notebook-level globals defined in earlier cells: ``args``,
    ``beacon``, ``maf_values``, ``binary`` and the PPO hyperparameters
    (``lr_actor``, ``lr_critic``, ``gamma``, ``K_epochs``, ``eps_clip``,
    ``has_continuous_action_space``, ``action_std``, ``random_seed``).
    Returns nothing; the work is done by ``train``.
    """
    # Honor the `random_seed` knob from the hyperparameter cell. It was never
    # read anywhere, and imported modules cannot see notebook globals, so
    # setting it had no effect. 0 keeps the previous unseeded behavior.
    # Seeding happens before the environment is built.
    if random_seed:
        torch.manual_seed(random_seed)
        np.random.seed(random_seed)

    env = BeaconEnv(args, beacon, maf_values, binary)
    # NOTE(review): the factor 4 is presumably the number of features per
    # gene in the flattened state -- confirm against BeaconEnv.
    state_dim = args.beacon_size * args.gene_size * 4
    action_dim = env.action_space.shape[0]

    # initialize a PPO agent
    ppo_agent = PPO(state_dim, action_dim, lr_actor, lr_critic, gamma, K_epochs, eps_clip, has_continuous_action_space, action_std)

    train(args, env, ppo_agent)

# Always true inside a notebook kernel; the guard keeps the cell importable if
# it is ever exported to a script.
if __name__ == '__main__':
    main()

Device set to : NVIDIA GeForce RTX 2080 Ti
Victim is inside the Beacon!


  self.attacker_state = torch.tensor([self.victim, self.mafs, [0]*len(self.victim)], dtype=torch.float32).transpose(0, 1)
  from .autonotebook import tqdm as notebook_tqdm


Started training at (GMT) :  2024-04-14 13:26:41
current logging run number for  :  2
logging at : ./results/PPO__log_2.csv
save checkpoint path : ./results/weights/PPO_2.pth
Reseting the Populations
Victim is inside the Beacon!
Episode:  0
lrt:  tensor(0.)
lrt:  tensor(0.)
lrt:  tensor(-9.5395e+15)
lrt:  tensor(-2.3519e+18)
lrt:  tensor(-2.3519e+18)
lrt:  tensor(-2.3519e+18)
lrt:  tensor(-2.3519e+18)
lrt:  tensor(-2.3519e+18)
lrt:  tensor(-2.3519e+18)
lrt:  tensor(-2.3519e+18)
Episode:  1
lrt:  tensor(0.)
lrt:  tensor(-60036872.)
lrt:  tensor(-60051392.)
lrt:  tensor(-61072748.)
lrt:  tensor(-47357208.)
lrt:  tensor(-47357208.)
lrt:  tensor(-47357208.)
lrt:  tensor(-8.0223e+10)
lrt:  tensor(-8.0223e+10)
lrt:  tensor(-8.0226e+10)
Episode:  2
lrt:  tensor(-6.2540e+09)
lrt:  tensor(-6.2541e+09)
lrt:  tensor(-6.2543e+09)
lrt:  tensor(-6.2558e+09)
lrt:  tensor(-6.2578e+09)
lrt:  tensor(-6.2578e+09)
lrt:  tensor(-6.2596e+09)
lrt:  tensor(-6.2596e+09)
lrt:  tensor(-6.2599e+09)
lrt:  tensor(-

KeyboardInterrupt: 