Skip to content

Commit

Permalink
typing
Browse files Browse the repository at this point in the history
  • Loading branch information
arunppsg committed Jul 21, 2023
1 parent 73286a9 commit 29b796e
Showing 1 changed file with 7 additions and 5 deletions.
12 changes: 7 additions & 5 deletions deepchem/models/torch_models/infograph.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import GRU, Linear, ReLU, Sequential
from typing import Iterable, List, Tuple, Optional
from typing import Iterable, List, Tuple, Optional, Dict
from deepchem.metrics import to_one_hot

import deepchem as dc
Expand Down Expand Up @@ -390,7 +390,7 @@ def build_components(self) -> dict:
fc1: MultilayerPerceptron, dense layer used during finetuning
fc2: MultilayerPerceptron, dense layer used during finetuning
"""
components = {}
components: Dict[str, nn.Module] = {}
if self.task == 'pretraining':
components['encoder'] = GINEncoder(self.num_features,
self.embedding_dim,
Expand All @@ -413,15 +413,17 @@ def build_components(self) -> dict:
self.num_gc_layers)
components['fc1'] = torch.nn.Linear(self.embedding_dim,
self.embedding_dim)
# n_tasks is Optional[int] while argument 2 of nn.Linear has to be of type int
components['fc2'] = torch.nn.Linear(self.embedding_dim,
self.n_tasks)
self.n_tasks) # type: ignore
return components

def build_model(self) -> nn.Module:
if self.task == 'pretraining':
return InfoGraph(**self.components)
model = InfoGraph(**self.components)
elif self.task == 'regression':
return InfoGraphFinetune(**self.components)
model = InfoGraphFinetune(**self.components)
return model

def loss_func(self, inputs, labels, weights):
if self.task == 'pretraining':
Expand Down

0 comments on commit 29b796e

Please sign in to comment.