<a href="https://colab.research.google.com/github/chrlatte/TransferLearning/blob/main/Transfer_learning.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Install the ConPLex package; it provides the `conplex-dti` command used in the next cell.
!pip install conplex-dti


In [None]:
# Download the pretrained ConPLex_v1_BindingDB checkpoint. With `--to .` the file is
# written to ./models/ConPLex_v1_BindingDB.pt under the current working directory
# (see the log output below).
!conplex-dti download --to . --models ConPLex_v1_BindingDB

2023-07-12 13:38:59,158 [INFO] Download Location: /content
2023-07-12 13:38:59,158 [INFO] 
2023-07-12 13:38:59,158 [INFO] [BENCHMARKS]
2023-07-12 13:38:59,158 [INFO] [MODELS]
2023-07-12 13:38:59,158 [INFO] Downloading ConPLex_v1_BindingDB from https://cb.csail.mit.edu/cb/conplex/data/models/BindingDB_ExperimentalValidModel.pt to /content/models/ConPLex_v1_BindingDB.pt...


In [9]:
import torch
import torch.nn as nn
class Cosine(nn.Module):
    """Parameter-free module that scores two batches by cosine similarity."""

    def forward(self, x1, x2):
        """Return the cosine similarity of `x1` and `x2` computed along dim 1.

        `dim=1` and `eps=1e-8` are the defaults of `nn.CosineSimilarity`, spelled
        out here so the functional call is equivalent to the module form.
        """
        return nn.functional.cosine_similarity(x1, x2, dim=1, eps=1e-8)


# Registry mapping a `latent_distance` option name to the module class that scores a
# pair of (drug, target) projections. The co-embedding models look the class up by
# name and instantiate it with no arguments.
DISTANCE_METRICS = {
    "Cosine": Cosine
}

# Registry mapping a `latent_activation` option name to the torch.nn activation class
# of the same name (classes, not instances; callers instantiate them per layer).
ACTIVATIONS = {
    activation_name: getattr(nn, activation_name)
    for activation_name in ("ReLU", "GELU", "ELU", "Sigmoid")
}


class SimpleCoembedding(nn.Module):
    """Project drug and target feature vectors into one latent space and score the pair.

    Each input goes through its own single `Linear -> activation` projector. The
    pair is then scored either with the configured distance-metric module
    (`classify=True`) or with a plain inner product (`classify=False`).

    Args:
        drug_shape: Input feature size of the drug projector.
        target_shape: Input feature size of the target projector.
        latent_dimension: Output size of both projectors.
        latent_activation: Key into `ACTIVATIONS` selecting the activation class.
        latent_distance: Key into `DISTANCE_METRICS`; only used when `classify` is true.
        classify: Selects `classify()` (true) or `regress()` (false) in `forward()`.
    """

    def __init__(
        self,
        drug_shape=2048,
        target_shape=1024,
        latent_dimension=1024,
        latent_activation="ReLU",
        latent_distance="Cosine",
        classify=True,
    ):
        super().__init__()
        self.drug_shape = drug_shape
        self.target_shape = target_shape
        self.latent_dimension = latent_dimension
        self.do_classify = classify
        # Stored as a class; a fresh instance is created for each projector.
        self.latent_activation = ACTIVATIONS[latent_activation]

        # Drug branch: the linear layer is created first, then re-initialised
        # with Xavier-normal weights (the bias keeps nn.Linear's default init).
        drug_linear = nn.Linear(self.drug_shape, latent_dimension)
        self.drug_projector = nn.Sequential(drug_linear, self.latent_activation())
        nn.init.xavier_normal_(drug_linear.weight)

        # Target branch, built the same way.
        target_linear = nn.Linear(self.target_shape, latent_dimension)
        self.target_projector = nn.Sequential(target_linear, self.latent_activation())
        nn.init.xavier_normal_(target_linear.weight)

        # The scoring module only exists in classification mode.
        if self.do_classify:
            self.distance_metric = latent_distance
            self.activator = DISTANCE_METRICS[self.distance_metric]()

    def forward(self, drug, target):
        """Score a (drug, target) batch with the head selected at construction."""
        score_fn = self.classify if self.do_classify else self.regress
        return score_fn(drug, target)

    def regress(self, drug, target):
        """Return the inner product of the two projections, one value per pair.

        The projections are reshaped to (-1, 1, D) and (-1, D, 1) and multiplied
        with `torch.bmm`. `squeeze()` then drops every size-1 dimension, so a
        batch of one yields a 0-d tensor.
        """
        drug_rows = self.drug_projector(drug).view(-1, 1, self.latent_dimension)
        target_cols = self.target_projector(target).view(-1, self.latent_dimension, 1)
        return torch.bmm(drug_rows, target_cols).squeeze()

    def classify(self, drug, target):
        """Return the distance-metric score of the two projections.

        `squeeze()` drops every size-1 dimension of the metric's output.
        """
        score = self.activator(self.drug_projector(drug), self.target_projector(target))
        return score.squeeze()


# Alias under which the later cells construct the model.
SimpleCoembeddingNoSigmoid = SimpleCoembedding

In [11]:
# The checkpoint path is relative because the download cell used `--to .`, which
# writes into ./models under the current working directory. A hard-coded /content
# prefix only resolves on Colab.
CHECKPOINT_PATH = "models/ConPLex_v1_BindingDB.pt"

# Positional arguments are drug_shape=2048, target_shape=1024, latent_dimension=1024.
model = SimpleCoembeddingNoSigmoid(2048, 1024, 1024)

# torch.load unpickles the file. The checkpoint is downloaded from the network, so
# weights_only=True limits unpickling to tensors and plain containers instead of
# arbitrary Python objects. map_location="cpu" lets it load without a GPU.
pretrained_state = torch.load(CHECKPOINT_PATH, map_location="cpu", weights_only=True)
model.load_state_dict(pretrained_state)
# Left disabled by the author; `device` is not defined in the cells above.
# model = model.eval()
# model = model.to(device)


<All keys matched successfully>

In [13]:
# Print the module tree of the loaded model to inspect its layers.
print(model)

SimpleCoembedding(
  (drug_projector): Sequential(
    (0): Linear(in_features=2048, out_features=1024, bias=True)
    (1): ReLU()
  )
  (target_projector): Sequential(
    (0): Linear(in_features=1024, out_features=1024, bias=True)
    (1): ReLU()
  )
  (activator): Cosine()
)


In [17]:
# NeuralNetwork is not defined in any cell above, so on a fresh kernel this cell
# raised NameError. It is defined here so the cell runs on its own; the layer layout
# reproduces the module summary printed in this cell's output.
class NeuralNetwork(nn.Module):
    """Three-layer MLP (784 -> 512 -> 512 -> 10) used only as a `print` example."""

    def __init__(self):
        super().__init__()
        self.flatten = nn.Flatten()
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(784, 512),
            nn.ReLU(),
            nn.Linear(512, 512),
            nn.ReLU(),
            nn.Linear(512, 10),
        )

    def forward(self, x):
        """Flatten everything after the batch dimension and apply the linear stack.

        NOTE(review): the original forward pass is not visible in this notebook;
        this is the direct composition of the two printed sub-modules. Confirm.
        """
        return self.linear_relu_stack(self.flatten(x))


# Bound to its own name so the pretrained ConPLex model held in `model` is not
# overwritten by this unrelated example network.
reference_model = NeuralNetwork()
print(reference_model)

NeuralNetwork(
  (flatten): Flatten(start_dim=1, end_dim=-1)
  (linear_relu_stack): Sequential(
    (0): Linear(in_features=784, out_features=512, bias=True)
    (1): ReLU()
    (2): Linear(in_features=512, out_features=512, bias=True)
    (3): ReLU()
    (4): Linear(in_features=512, out_features=10, bias=True)
  )
)


In [23]:
class TransferCoembedding(nn.Module):
    """Co-embedding model with one extra trainable layer per projector.

    Each projector is `Linear -> activation -> Linear -> activation`. Layer [0]
    has the same shape as the single linear layer of `SimpleCoembedding`, so it
    can be seeded from a pretrained model and frozen. Layer [2] is new and is
    always trainable.

    Args:
        drug_shape: Input feature size of the drug projector.
        target_shape: Input feature size of the target projector.
        latent_dimension: Output size of every linear layer in both projectors.
        latent_activation: Key into `ACTIVATIONS` selecting the activation class.
        latent_distance: Key into `DISTANCE_METRICS`; only used when `classify` is true.
        classify: Selects `classify()` (true) or `regress()` (false) in `forward()`.
        pretrained: Optional module exposing `drug_projector[0]` and
            `target_projector[0]` linear layers whose weights and biases are
            copied into layer [0] of the matching projector here. With the
            default `None`, layer [0] gets Xavier-normal weights as before.
        freeze_pretrained: When `pretrained` is given, disable gradients on the
            copied layers. Ignored when `pretrained` is `None`.

    Raises:
        KeyError: If `latent_activation` or `latent_distance` is not registered.
        RuntimeError: If a pretrained layer's parameter shapes do not match.
    """

    def __init__(
        self,
        drug_shape=2048,
        target_shape=1024,
        latent_dimension=1024,
        latent_activation="ReLU",
        latent_distance="Cosine",
        classify=True,
        pretrained=None,
        freeze_pretrained=True,
    ):
        super().__init__()
        self.drug_shape = drug_shape
        self.target_shape = target_shape
        self.latent_dimension = latent_dimension
        self.do_classify = classify
        # Stored as a class; a fresh instance is created for each activation slot.
        self.latent_activation = ACTIVATIONS[latent_activation]

        self.drug_projector = self._build_projector(
            self.drug_shape,
            latent_dimension,
            self.latent_activation,
            None if pretrained is None else pretrained.drug_projector[0],
            freeze_pretrained,
        )
        self.target_projector = self._build_projector(
            self.target_shape,
            latent_dimension,
            self.latent_activation,
            None if pretrained is None else pretrained.target_projector[0],
            freeze_pretrained,
        )

        # The scoring module only exists in classification mode.
        if self.do_classify:
            self.distance_metric = latent_distance
            self.activator = DISTANCE_METRICS[self.distance_metric]()

    @staticmethod
    def _build_projector(in_features, latent_dimension, activation, pretrained_layer=None, freeze=True):
        """Build one projector and initialise its first linear layer.

        Without `pretrained_layer`, layer [0] gets Xavier-normal weights. With
        it, layer [0] receives a copy of that layer's parameters and is frozen
        when `freeze` is true. Layer [2] keeps nn.Linear's default random
        initialisation in both cases.
        """
        projector = nn.Sequential(
            nn.Linear(in_features, latent_dimension),       # [0] pretrained-compatible
            activation(),                                   # [1]
            nn.Linear(latent_dimension, latent_dimension),  # [2] new transfer layer
            activation(),                                   # [3]
        )
        first_layer = projector[0]
        if pretrained_layer is None:
            nn.init.xavier_normal_(first_layer.weight)
        else:
            # load_state_dict copies the values (no tensor sharing with the
            # source layer) and raises on any shape mismatch.
            first_layer.load_state_dict(pretrained_layer.state_dict())
            if freeze:
                first_layer.requires_grad_(False)
        return projector

    def forward(self, drug, target):
        """Score a (drug, target) batch with the head selected at construction."""
        if self.do_classify:
            return self.classify(drug, target)
        return self.regress(drug, target)

    def classify(self, drug, target):
        """Return the distance-metric score of the two projections.

        `squeeze()` drops every size-1 dimension of the metric's output.
        """
        drug_projection = self.drug_projector(drug)
        target_projection = self.target_projector(target)
        return self.activator(drug_projection, target_projection).squeeze()

    def regress(self, drug, target):
        """Return the inner product of the two projections, one value per pair.

        The projections are reshaped to (-1, 1, D) and (-1, D, 1) and multiplied
        with `torch.bmm`. `squeeze()` then drops every size-1 dimension, so a
        batch of one yields a 0-d tensor.
        """
        drug_projection = self.drug_projector(drug)
        target_projection = self.target_projector(target)
        inner_prod = torch.bmm(
            drug_projection.view(-1, 1, self.latent_dimension),
            target_projection.view(-1, self.latent_dimension, 1),
        )
        # A single squeeze is enough; the original applied it twice with no effect.
        return inner_prod.squeeze()
