<a href="https://colab.research.google.com/github/chrlatte/TransferLearning/blob/main/Transfer_learning.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install conplex-dti


In [None]:
!conplex-dti download --to . --models ConPLex_v1_BindingDB

2023-07-12 13:38:59,158 [INFO] Download Location: /content
2023-07-12 13:38:59,158 [INFO] 
2023-07-12 13:38:59,158 [INFO] [BENCHMARKS]
2023-07-12 13:38:59,158 [INFO] [MODELS]
2023-07-12 13:38:59,158 [INFO] Downloading ConPLex_v1_BindingDB from https://cb.csail.mit.edu/cb/conplex/data/models/BindingDB_ExperimentalValidModel.pt to /content/models/ConPLex_v1_BindingDB.pt...


In [None]:
import torch
import torch.nn as nn
class Cosine(nn.Module):
    """Cosine similarity between two batches of vectors, compared along dim 1.

    Uses the functional form with PyTorch's defaults (``dim=1``, ``eps=1e-8``),
    the same defaults ``nn.CosineSimilarity()`` uses. Has no parameters.
    """

    def forward(self, x1, x2):
        return nn.functional.cosine_similarity(x1, x2)


# Pairwise scoring modules, looked up by the ``latent_distance`` constructor
# argument; values are classes and are instantiated by the model.
DISTANCE_METRICS = dict(Cosine=Cosine)

# Activation classes (not instances), looked up by the ``latent_activation``
# constructor argument; the model instantiates one per projector.
ACTIVATIONS = dict(
    ReLU=nn.ReLU,
    GELU=nn.GELU,
    ELU=nn.ELU,
    Sigmoid=nn.Sigmoid,
)


class SimpleCoembedding(nn.Module):
    """Project drug and target features into a shared latent space and score pairs.

    Each input goes through one ``Linear`` layer followed by the configured
    activation. In classification mode the two projections are compared with
    the metric named by ``latent_distance``; otherwise the score is their
    dot product.

    Args:
        drug_shape: Width of the drug feature vectors.
        target_shape: Width of the target feature vectors.
        latent_dimension: Width of the shared latent space.
        latent_activation: Key into ``ACTIVATIONS`` selecting the activation
            applied after each projection.
        latent_distance: Key into ``DISTANCE_METRICS``; only used when
            ``classify`` is true.
        classify: If true, ``forward`` returns the distance metric; otherwise
            it returns the latent inner product.
    """

    def __init__(
        self,
        drug_shape=2048,
        target_shape=1024,
        latent_dimension=1024,
        latent_activation="ReLU",
        latent_distance="Cosine",
        classify=True,
    ):
        super().__init__()
        self.drug_shape = drug_shape
        self.target_shape = target_shape
        self.latent_dimension = latent_dimension
        self.do_classify = classify
        self.latent_activation = ACTIVATIONS[latent_activation]

        # Names/indices must not change: checkpoint keys look like
        # "drug_projector.0.weight".
        self.drug_projector = self._make_projector(self.drug_shape)
        self.target_projector = self._make_projector(self.target_shape)

        if classify:
            self.distance_metric = latent_distance
            self.activator = DISTANCE_METRICS[latent_distance]()

    def _make_projector(self, in_features):
        """Build ``Linear(in_features, latent_dimension)`` + activation, Xavier-normal weights."""
        projector = nn.Sequential(
            nn.Linear(in_features, self.latent_dimension),
            self.latent_activation(),
        )
        nn.init.xavier_normal_(projector[0].weight)
        return projector

    def forward(self, drug, target):
        """Dispatch to ``classify`` or ``regress`` according to ``do_classify``."""
        scorer = self.classify if self.do_classify else self.regress
        return scorer(drug, target)

    def regress(self, drug, target):
        """Return the dot product of the latent drug and target projections.

        Every size-1 dimension is squeezed away, so a single pair yields a
        0-d tensor.
        """
        latent = self.latent_dimension
        drug_rows = self.drug_projector(drug).view(-1, 1, latent)
        target_cols = self.target_projector(target).view(-1, latent, 1)
        return torch.bmm(drug_rows, target_cols).squeeze()

    def classify(self, drug, target):
        """Return ``activator(drug_latent, target_latent)`` with size-1 dims squeezed."""
        distance = self.activator(
            self.drug_projector(drug), self.target_projector(target)
        )
        return distance.squeeze()


# Second name for the same class; the notebook constructs the pre-trained
# model through this alias.
SimpleCoembeddingNoSigmoid = SimpleCoembedding

In [None]:
# Build the architecture the BindingDB checkpoint was trained with
# (2048-d drug features, 1024-d target features, 1024-d latent space).
model = SimpleCoembeddingNoSigmoid(2048, 1024, 1024)
# The checkpoint comes from the network; weights_only=True restricts torch.load's
# unpickler to tensors and plain containers instead of arbitrary objects.
state_dict = torch.load(
    "/content/models/ConPLex_v1_BindingDB.pt",
    map_location="cpu",
    weights_only=True,
)
# Left as the cell's final expression so the notebook displays the key-match report.
model.load_state_dict(state_dict)


<All keys matched successfully>

In [None]:
# Show the pre-trained module tree (nn.Module's repr lists every submodule).
print(repr(model))

SimpleCoembedding(
  (drug_projector): Sequential(
    (0): Linear(in_features=2048, out_features=1024, bias=True)
    (1): ReLU()
  )
  (target_projector): Sequential(
    (0): Linear(in_features=1024, out_features=1024, bias=True)
    (1): ReLU()
  )
  (activator): Cosine()
)


In [None]:
# Dump every parameter tensor of the pre-trained model.
for weight in model.parameters():
    print(weight)


Parameter containing:
tensor([[ 0.0047, -0.0079,  0.0238,  ...,  0.0111,  0.0210,  0.0207],
        [ 0.0153,  0.0114, -0.0300,  ...,  0.0068, -0.0462, -0.0392],
        [ 0.0016,  0.0115, -0.0180,  ..., -0.0070, -0.0049, -0.0085],
        ...,
        [ 0.0366, -0.0195,  0.0033,  ...,  0.0002, -0.0111, -0.0228],
        [ 0.0404,  0.0027,  0.0098,  ..., -0.0241, -0.0017, -0.0129],
        [-0.0063,  0.0096, -0.0328,  ...,  0.0028,  0.0066, -0.0292]],
       requires_grad=True)
Parameter containing:
tensor([-0.0171, -0.0019, -0.0151,  ..., -0.0408, -0.0276, -0.0064],
       requires_grad=True)
Parameter containing:
tensor([[-0.0145, -0.0040,  0.0165,  ...,  0.0232,  0.0239, -0.0087],
        [ 0.0004,  0.0061, -0.0367,  ..., -0.0063, -0.0146, -0.0555],
        [-0.0189, -0.0549,  0.0328,  ..., -0.0027, -0.0060,  0.0128],
        ...,
        [ 0.0331, -0.0179, -0.0203,  ...,  0.0261, -0.0122, -0.0455],
        [ 0.0083, -0.0030,  0.0434,  ...,  0.0112,  0.0150, -0.0119],
        [-0.01

In [None]:
class TransferCoembedding(nn.Module):
    """Transfer-learning variant of ``SimpleCoembedding`` with a deeper projector per branch.

    Each projector is ``Linear -> activation -> Linear -> activation``:

    * layer ``[0]`` gets a *copy* of the matching pre-trained layer ``[0]``
      and is frozen (its parameters have ``requires_grad=False``);
    * layer ``[2]`` is new, Xavier-normal initialised and trainable.

    Scoring matches ``SimpleCoembedding``: the ``latent_distance`` metric when
    ``classify`` is true, otherwise the latent inner product.

    Args:
        pre_trained_model: Model whose ``drug_projector[0]`` and
            ``target_projector[0]`` weights are copied. It is not modified.
        drug_shape: Width of drug features; must equal the pre-trained drug
            layer's ``in_features``.
        target_shape: Width of target features; must equal the pre-trained
            target layer's ``in_features``.
        latent_dimension: Latent width; must equal the pre-trained layers'
            ``out_features``.
        latent_activation: Key into ``ACTIVATIONS``.
        latent_distance: Key into ``DISTANCE_METRICS``; used only when
            ``classify`` is true.
        classify: Distance-metric scoring (True) or inner product (False).

    Raises:
        RuntimeError: From ``load_state_dict`` when the shapes above do not
            match the pre-trained layers.
        KeyError: If ``latent_activation`` or ``latent_distance`` is unknown.
    """

    def __init__(
        self,
        pre_trained_model:SimpleCoembedding,
        drug_shape=2048,
        target_shape=1024,
        latent_dimension=1024,
        latent_activation="ReLU",
        latent_distance="Cosine",
        classify=True,
    ):
        super().__init__()
        # TODO: initialize these all based on the pre-trained model instead of
        # relying on the defaults matching it.
        self.drug_shape = drug_shape
        self.target_shape = target_shape
        self.latent_dimension = latent_dimension
        self.do_classify = classify
        self.latent_activation = ACTIVATIONS[latent_activation]

        self.drug_projector = self._build_projector(
            self.drug_shape, pre_trained_model.drug_projector[0]
        )
        self.target_projector = self._build_projector(
            self.target_shape, pre_trained_model.target_projector[0]
        )

        if self.do_classify:
            self.distance_metric = latent_distance
            self.activator = DISTANCE_METRICS[self.distance_metric]()

    def _build_projector(self, in_features, pretrained_layer):
        """Return Linear->act->Linear->act with a frozen layer 0 copied from ``pretrained_layer``."""
        projector = nn.Sequential(
            nn.Linear(in_features, self.latent_dimension),            # [0] pre-trained, frozen
            self.latent_activation(),                                  # [1]
            nn.Linear(self.latent_dimension, self.latent_dimension),  # [2] new, trainable
            self.latent_activation(),                                  # [3]
        )
        # Copy the weights instead of reusing the module object, so the
        # pre-trained model is never aliased (freezing, .to(), or training
        # this model would otherwise mutate it too).
        projector[0].load_state_dict(pretrained_layer.state_dict())
        # Freeze the parameters themselves. Assigning ``module.requires_grad``
        # only creates a plain attribute and does not affect autograd.
        projector[0].requires_grad_(False)
        nn.init.xavier_normal_(projector[2].weight)
        return projector

    def forward(self, drug, target):
        """Dispatch to ``classify`` or ``regress`` according to ``do_classify``."""
        if self.do_classify:
            return self.classify(drug, target)
        return self.regress(drug, target)

    def classify(self, drug, target):
        """Return ``activator(drug_latent, target_latent)`` with size-1 dims squeezed."""
        drug_projection = self.drug_projector(drug)
        target_projection = self.target_projector(target)

        distance = self.activator(drug_projection, target_projection)
        return distance.squeeze()

    def regress(self, drug, target):
        """Return the latent dot product; a single pair yields a 0-d tensor."""
        drug_projection = self.drug_projector(drug)
        target_projection = self.target_projector(target)

        inner_prod = torch.bmm(
            drug_projection.view(-1, 1, self.latent_dimension),
            target_projection.view(-1, self.latent_dimension, 1),
        )
        return inner_prod.squeeze()


In [None]:
# Wrap the pre-trained model and show the extended (two-layer) projectors.
new_model = TransferCoembedding(pre_trained_model=model)
print(new_model)

TransferCoembedding(
  (drug_projector): Sequential(
    (0): Linear(in_features=2048, out_features=1024, bias=True)
    (1): ReLU()
    (2): Linear(in_features=1024, out_features=1024, bias=True)
    (3): ReLU()
  )
  (target_projector): Sequential(
    (0): Linear(in_features=1024, out_features=1024, bias=True)
    (1): ReLU()
    (2): Linear(in_features=1024, out_features=1024, bias=True)
    (3): ReLU()
  )
  (activator): Cosine()
)


In [None]:
# Dump every parameter of the transfer model; the layer-0 tensors should
# show the same values as the pre-trained model's.
for tensor in new_model.parameters():
    print(tensor)

Parameter containing:
tensor([[ 0.0047, -0.0079,  0.0238,  ...,  0.0111,  0.0210,  0.0207],
        [ 0.0153,  0.0114, -0.0300,  ...,  0.0068, -0.0462, -0.0392],
        [ 0.0016,  0.0115, -0.0180,  ..., -0.0070, -0.0049, -0.0085],
        ...,
        [ 0.0366, -0.0195,  0.0033,  ...,  0.0002, -0.0111, -0.0228],
        [ 0.0404,  0.0027,  0.0098,  ..., -0.0241, -0.0017, -0.0129],
        [-0.0063,  0.0096, -0.0328,  ...,  0.0028,  0.0066, -0.0292]])
Parameter containing:
tensor([-0.0171, -0.0019, -0.0151,  ..., -0.0408, -0.0276, -0.0064])
Parameter containing:
tensor([[ 0.0059,  0.0216, -0.0063,  ...,  0.0296, -0.0172,  0.0210],
        [ 0.0604,  0.0188,  0.0035,  ...,  0.0151, -0.0138, -0.0070],
        [-0.0070,  0.0336,  0.0219,  ...,  0.0250,  0.0536, -0.0202],
        ...,
        [ 0.0482, -0.0056,  0.0183,  ...,  0.0621, -0.0617, -0.0180],
        [-0.0243, -0.0202,  0.0776,  ...,  0.0086,  0.0609, -0.0489],
        [ 0.0073, -0.0032, -0.0295,  ...,  0.0618,  0.0370,  0.0129]