# In[1]:
import torch
from optimizer import LGSO

# In[2]:
class Classifier(torch.nn.Module):
    """Four-layer MLP that maps one concatenated feature vector to a probability.

    ``forward`` expects its input to already be the concatenation of the
    per-sample features and the design parameters, so the input's last
    dimension must equal ``x_dim + phi_dim``.

    Args:
        phi_dim: Number of design-parameter features.
        x_dim: Number of per-sample features (default 7).
        hidden_dim: Width of each of the three hidden layers (default 256).
    """

    def __init__(self, phi_dim, x_dim=7, hidden_dim=256):
        super().__init__()
        in_features = x_dim + phi_dim
        # Layers are registered in a fixed order (fc1..fc4, then activation) so
        # parameter initialisation order and ``state_dict`` keys stay stable.
        self.fc1 = torch.nn.Linear(in_features, hidden_dim)
        self.fc2 = torch.nn.Linear(hidden_dim, hidden_dim)
        self.fc3 = torch.nn.Linear(hidden_dim, hidden_dim)
        self.fc4 = torch.nn.Linear(hidden_dim, 1)
        self.activation = torch.nn.ReLU()

    def forward(self, x):
        """Apply three ReLU hidden layers and a sigmoid head; output last dim is 1."""
        hidden = x
        for layer in (self.fc1, self.fc2, self.fc3):
            hidden = self.activation(layer(hidden))
        return torch.sigmoid(self.fc4(hidden))
    
class HitsClassifier:
    """Ensemble of independently trained ``Classifier`` networks.

    The ensemble mean is the prediction; the ensemble variance serves as an
    uncertainty estimate.

    Args:
        n_models: Number of ensemble members. With ``n_models=1`` the variance
            returned by ``__call__(..., return_unc=True)`` is NaN, because
            ``torch.var`` uses the unbiased (n - 1) estimator.
        **classifier_kargs: Forwarded unchanged to every ``Classifier``.
    """

    def __init__(self,
                 n_models: int = 1,
                 **classifier_kargs) -> None:
        self.models = [Classifier(**classifier_kargs) for _ in range(n_models)]
        self.loss_fn = torch.nn.BCELoss()

    @staticmethod
    def _build_inputs(phi, x):
        """Tile ``phi`` and ``x`` row-wise and concatenate them along dim 1.

        With ``P = phi.size(0)`` and ``N = x.size(0)`` the result has ``P * N``
        rows, and row ``k`` is ``[phi[k % P], x[k % N]]``. With a single phi row
        (P == 1), that phi is attached to every row of ``x``.
        Assumes both ``phi`` and ``x`` are 2-D (rows, features) — TODO confirm.
        """
        # NOTE(review): for P > 1 and N > 1 this is NOT a full Cartesian product
        # (that would need ``phi.repeat_interleave(N, 0)``). Confirm that this
        # pairing matches how the caller lays out the targets ``y``.
        return torch.cat([phi.repeat(x.size(0), 1), x.repeat(phi.size(0), 1)], 1)

    def fit(self, phi, y, x, n_epochs: int = 1000, *,
            lr: float = 0.1, momentum: float = 0.9):
        """Train every ensemble member on the same data with full-batch SGD.

        Args:
            phi: Design parameters, tiled against ``x`` by ``_build_inputs``.
            y: Binary targets; must have the same shape as the model output,
                i.e. ``(phi.size(0) * x.size(0), 1)`` (required by ``BCELoss``).
            x: Per-sample features.
            n_epochs: Number of full-batch gradient steps per model.
            lr: SGD learning rate (default 0.1, the previous hard-coded value).
            momentum: SGD momentum (default 0.9, the previous hard-coded value).
        """
        # The input matrix is the same for every model and epoch, so build it once.
        inputs = self._build_inputs(phi, x)
        for model in self.models:
            optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=momentum)
            for _ in range(n_epochs):
                optimizer.zero_grad()
                loss = self.loss_fn(model(inputs), y)
                loss.backward()
                optimizer.step()

    def get_predictions(self, phi, x):
        """Return every member's predictions stacked along a new leading dim.

        Returns:
            Tensor of shape ``(n_models, phi.size(0) * x.size(0), 1)`` that
            carries no autograd graph.
        """
        inputs = self._build_inputs(phi, x)
        # ``torch.tensor`` cannot build a tensor from a list of multi-element
        # tensors; ``torch.stack`` can. ``no_grad`` is used because callers only
        # consume these values and do not backpropagate through them.
        with torch.no_grad():
            return torch.stack([model(inputs) for model in self.models])

    def __call__(self, phi, x, return_unc=False):
        """Return the ensemble mean, plus the ensemble variance if ``return_unc``."""
        predictions = self.get_predictions(phi, x)
        mean = torch.mean(predictions, dim=0)
        if return_unc:
            return mean, torch.var(predictions, dim=0)
        return mean

# In[5]:
class ActiveLCSO(LGSO):
    """Active-learning variant of ``LGSO`` that simulates only selected samples.

    Each iteration samples candidate ``x`` for the sampled ``phi``. Only the
    samples flagged by ``get_uncertain`` are passed to
    ``true_model.simulate`` and added to the history. Sampling, history
    bookkeeping, surrogate fitting and phi updates are delegated to ``LGSO``
    methods that are not visible in this file.

    NOTE(review): ``__init__`` does not call ``super().__init__()``. Any state
    that ``LGSO.__init__`` would set up (e.g. ``current_phi``, ``true_model``,
    ``history``, all used below) is therefore not initialised by this
    constructor; confirm how callers populate it. ``phi_dim`` is accepted but
    never used.
    """
    def __init__(self,p_threshold:float, unc_threshold:float, phi_dim:int) -> None:
        # Minimum predicted probability (inclusive) for a sample to be selected.
        self.p_threshold = p_threshold
        # Minimum uncertainty (inclusive), i.e. the second value the surrogate
        # returns when called with return_unc=True, for a sample to be selected.
        self.unc_threshold = unc_threshold
    def loss_fn(self,y):
        """Return the mean of ``log(y)``.

        Yields -inf if any element of ``y`` is 0 and NaN if any is negative.
        """
        return torch.mean(torch.log(y))
    def get_uncertain(self,phi,x):
        """Return a boolean mask of samples that meet both thresholds.

        A sample is selected when its predicted probability is >= ``p_threshold``
        AND its uncertainty is >= ``unc_threshold``.
        """
        # NOTE(review): ``self.self`` looks like a typo for the surrogate-model
        # attribute provided by ``LGSO``. Unless something assigns ``self.self``,
        # this line raises AttributeError. Confirm the intended attribute name.
        predictions,unc = self.self(phi,x,return_unc = True)
        return torch.logical_and(predictions.ge(self.p_threshold),unc.ge(self.unc_threshold))
    def optimization_iteration(self):
        """Run one active-learning step and return ``self.get_optimal()``."""
        sampled_phi = self.sample_phi(self.current_phi)
        x = self.true_model.sample_x(sampled_phi)
        uncertain_mask = self.get_uncertain(sampled_phi,x)
        # NOTE(review): if the surrogate returns predictions shaped (N, 1), as
        # HitsClassifier.__call__ would, the mask is (N, 1). Indexing a 2-D ``x``
        # that has more than one column with it raises an IndexError; consider
        # ``uncertain_mask.view(-1)``. Confirm the shape with the surrogate.
        y = self.true_model.simulate(sampled_phi,x[uncertain_mask])
        self.update_history(sampled_phi,y,x[uncertain_mask])
        self.fit_surrogate_model()
        self.get_new_phi()
        return self.get_optimal()
    def clean_training_data(self):
        """Return ``self.history`` unchanged (no filtering is applied here).

        Presumably overrides an ``LGSO`` hook; TODO confirm.
        """
        return self.history
