In [None]:
import numpy as np
import torch
import torch.nn.functional as F
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, random_split
from MayaDataset import CNNCLF
import MayaDataset
from DefenderGAN import Warmup
from Models import RNNGenerator2, RNNInference, RNNGenerator3, RNNInference3
import argparse
import time
import os
victim = 'video'
SEED = 0          # fixes the data split and model init below; change to resample
BATCH_SIZE = 32
NUM_WORKERS = 4
SPLIT_SIZE = 1000  # samples in each of the train / val / test subsets

# Seed torch's global RNG so CNNCLF / RNNGenerator3 initialisation is reproducible.
torch.manual_seed(SEED)

dataset = MayaDataset.MayaDataset('traces/aml_video/', minpower=25, maxpower=225, window=900, labels=victim)
assert len(dataset) >= 3 * SPLIT_SIZE, "dataset too small for three {}-sample splits".format(SPLIT_SIZE)

# Three disjoint SPLIT_SIZE-sample subsets; the remainder is unused.
# random_split is unseeded by default, so pass an explicit generator to make the
# split identical across kernel restarts.
dsets = random_split(
    dataset,
    [SPLIT_SIZE, SPLIT_SIZE, SPLIT_SIZE, len(dataset) - 3 * SPLIT_SIZE],
    generator=torch.Generator().manual_seed(SEED),
)
trainset, valset, testset = dsets[0], dsets[1], dsets[2]
trainloader = DataLoader(trainset, batch_size=BATCH_SIZE, num_workers=NUM_WORKERS)
valloader = DataLoader(valset, batch_size=BATCH_SIZE, num_workers=NUM_WORKERS)
testloader = DataLoader(testset, batch_size=BATCH_SIZE, num_workers=NUM_WORKERS)

dim = 64
clf = CNNCLF(dataset.window // 3).cuda()
gen = RNNGenerator3(dim).cuda()
#bestname = './best_{}_{}_{}.pth'.format(victim,'rnn2',dim)
bestname = './best_{}_{}.pth'.format('rnn3', dim)
if os.path.isfile(bestname):
    print('Previous best found: loading the model...')
    gen.load_state_dict(torch.load(bestname))
# NOTE(review): no optimizer in this notebook updates `gen`, yet it is left in train
# mode; if RNNGenerator3 contains dropout/batch-norm its perturbations will be
# stochastic — confirm this is intended.
gen.train()


In [None]:
# Cooldown: train the CNN classifier `clf` on generator-perturbed traces (gen is held
# fixed), keep the best validation snapshot from the second half of training, then
# report test accuracy and original vs. perturbed mean power.
optim_c = torch.optim.Adam(clf.parameters(), lr=0.0001)
criterion = nn.CrossEntropyLoss()
bestacc = 0.0
epochs = 30
# Used to map normalised mean power back to the dataset's [minpower, maxpower] scale.
prange = dataset.maxpower - dataset.minpower
pmin = dataset.minpower


def _evaluate(loader):
    """Evaluate `clf` on `gen`-perturbed inputs drawn from `loader`.

    Returns (accuracy, mean original power, mean perturbed power); both power values
    are rescaled with `prange`/`pmin`. Batch means are weighted by batch size so a
    short final batch does not skew the average.

    Raises ValueError if `loader` yields no samples.
    """
    clf.eval()
    correct = 0
    count = 0
    org_sum = 0.0
    new_sum = 0.0
    with torch.no_grad():  # evaluation only: no autograd graph through gen/clf
        for x, y in loader:
            xdata, ydata = x.cuda(), y.cuda()
            p_input = gen(xdata)
            batch = y.size(0)
            org_sum += xdata.mean().item() * batch
            new_sum += p_input.mean().item() * batch
            pred = clf(p_input).argmax(dim=-1)
            correct += (pred == ydata).sum().item()
            count += batch
    if count == 0:
        raise ValueError("loader yielded no samples")
    acc = float(correct) / count
    return acc, (org_sum / count) * prange + pmin, (new_sum / count) * prange + pmin


def _snapshot(model):
    """Return a detached copy of `model`'s state dict.

    state_dict() returns references to the live tensors, so without cloning the
    saved "best" weights would keep changing as training continues.
    """
    return {k: v.detach().clone() for k, v in model.state_dict().items()}


bestclf = None
for e in range(epochs):
    clf.train()
    for x, y in trainloader:
        xdata, ydata = x.cuda(), y.cuda()
        with torch.no_grad():  # the generator is fixed here; only clf is optimised
            p_input = gen(xdata)
        optim_c.zero_grad()
        loss = criterion(clf(p_input), ydata)
        loss.backward()
        optim_c.step()
    macc, orgpower, newpower = _evaluate(valloader)
    # Only snapshot during the second half of training.
    if macc > bestacc and e > epochs // 2:
        bestacc = macc
        bestclf = _snapshot(clf)
    print("Cooldown {}\t acc {:.4f}\torgpower {:.2f}\t newpower: {:.2f}".format(e + 1, macc, orgpower, newpower))

if bestclf is None:
    # No late epoch beat bestacc; fall back to the final weights instead of a NameError.
    bestclf = _snapshot(clf)
clf.load_state_dict(bestclf)
macc, orgpower, newpower = _evaluate(testloader)
print("Test acc {:.4f}\torgpower {:.2f}\t newpower: {:.2f}".format(macc, orgpower, newpower))

In [None]:
# Re-evaluate the best cooldown snapshot (`bestclf`) on the test split: accuracy plus
# original vs. perturbed mean power, rescaled with prange/pmin.
clf.load_state_dict(bestclf)
clf.eval()
totcorrect = 0
totcount = 0
orgpower = 0.0
newpower = 0.0
with torch.no_grad():  # inference only; avoids building an autograd graph through gen
    for x, y in testloader:
        xdata, ydata = x.cuda(), y.cuda()
        p_input = gen(xdata)
        batch = y.size(0)
        # Weight each batch mean by its size so a short final batch isn't over-counted.
        orgpower += xdata.mean().item() * batch
        newpower += p_input.mean().item() * batch
        pred = clf(p_input).argmax(dim=-1)
        totcorrect += (pred == ydata).sum().item()
        totcount += batch
macc = float(totcorrect) / totcount
# Normalise by the number of *test* samples (previously divided by len(valloader)).
orgpower = orgpower / totcount * prange + pmin
newpower = newpower / totcount * prange + pmin
print("Test acc {:.4f}\torgpower {:.2f}\t newpower: {:.2f}".format(macc, orgpower, newpower))

In [None]:
# Collect gen-perturbed traces and labels as plain NumPy arrays / Python scalars for
# the scikit-learn baseline. Note: the "valid" set here is the test split.
def _perturbed_dataset(loader):
    """Return (features, labels) for every sample in `loader` after passing it through `gen`.

    features: list with one NumPy array per sample (perturbed trace, on CPU).
    labels:   list of Python scalars converted from the label tensors, so downstream
              sklearn/accuracy code does not receive torch tensors.
    """
    features, labels = [], []
    with torch.no_grad():  # gen is only used for inference here
        for x, y in loader:
            perturbed = gen(x.cuda()).cpu().numpy()
            features.extend(perturbed)   # iterating the batch yields one row per sample
            labels.extend(y.tolist())
    return features, labels


train_x, train_y = _perturbed_dataset(trainloader)
valid_x, valid_y = _perturbed_dataset(testloader)
    

In [None]:

from sklearn import svm

# Linear-SVM baseline trained on the perturbed traces; predicts labels for valid_x.
# NOTE: this rebinds `clf`, shadowing the CNN classifier from earlier cells, so the
# CNN evaluation cells will fail if re-run after this one.
clf = svm.LinearSVC()
pred_y = clf.fit(train_x, train_y).predict(valid_x)  # fit() returns the estimator itself

In [None]:
# Per-sample hit flags for the SVM predictions, then their mean (accuracy).
correct = []
for predicted, actual in zip(pred_y, valid_y):
    correct.append(predicted == actual)
sum(correct) / len(correct)

In [None]:
# Show the fitted SVM's weight-matrix and bias shapes, one per line.
print(clf.coef_.shape, clf.intercept_.shape, sep="\n")

In [None]:
# Export: copy the generator's parameters into an RNNInference wrapper and trace it
# with example inputs of shape (1, 32) and (1, dim).
# NOTE(review): `gen` is an RNNGenerator3, but this builds RNNInference rather than
# the imported RNNInference3 — confirm the two are parameter-compatible.
model = RNNInference(dim)
cpu_gen = gen.cpu()  # Module.cpu() moves `gen` itself to CPU and returns it
model.copy_params(cpu_gen)
trace_inputs = (torch.rand(1, 32), torch.rand(1, dim))
script_module = torch.jit.trace(model, trace_inputs)

In [None]:
# Persist the traced module; the filename encodes victim, generator variant and hidden size.
export_path = './cpuscript_{}_{}_{}.pt'.format(victim, 'rnn3', dim)
script_module.save(export_path)

In [None]:
# Inspect the shape of the exported model's encoder[1] weight.
model.encoder[1].weight.shape
