Skip to content

Commit

Permalink
add code of Transformer
Browse files Browse the repository at this point in the history
  • Loading branch information
L-M-Sherlock committed Jan 5, 2024
1 parent 7715ce0 commit 4801583
Showing 1 changed file with 35 additions and 2 deletions.
37 changes: 35 additions & 2 deletions other.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,37 @@ def forgetting_curve(self, t, s):
return 0.9 ** (t / s)


class Transformer(nn.Module):
def __init__(self, state_dict=None):
super().__init__()
self.n_input = n_input
self.n_hidden = n_hidden
self.n_output = n_output
self.n_layers = n_layers
self.transformer = nn.Transformer(
d_model=n_input,
nhead=n_input,
num_encoder_layers=n_layers * 2,
num_decoder_layers=0,
dim_feedforward=n_hidden,
dropout=0,
)
self.fc = nn.Linear(n_input, n_output)

if state_dict is not None:
self.load_state_dict(state_dict)

def forward(self, x):
output = self.transformer.encoder(x)
output = output.mean(dim=0)
output = self.fc(output)
output = torch.exp(output).repeat(x.shape[0], 1, 1)
return output, None

def forgetting_curve(self, t, s):
return 0.9 ** (t / s)


class HLR(nn.Module):
def __init__(self, state_dict=None):
super().__init__()
Expand Down Expand Up @@ -498,7 +529,7 @@ def __init__(self, MDOEL) -> None:

def predict(self, t_history: str, r_history: str):
with torch.no_grad():
if isinstance(self.model, RNN):
if isinstance(self.model, RNN) or isinstance(self.model, Transformer):
line_tensor = lineToTensorRNN(
list(zip([t_history], [r_history]))[0]
).unsqueeze(1)
Expand Down Expand Up @@ -587,7 +618,7 @@ def create_features(df, model_name="FSRSv3"):
for t_sublist, r_sublist in zip(t_history, r_history)
for t_item, r_item in zip(t_sublist, r_sublist)
]
elif model_name == "LSTM":
elif model_name in ("LSTM", "Transformer"):
df["tensor"] = [
torch.tensor(
[t_item[:-1]]
Expand Down Expand Up @@ -635,6 +666,8 @@ def process(args):
model = FSRS3
elif model_name == "HLR":
model = HLR
elif model_name == "Transformer":
model = Transformer

dataset = create_features(dataset, model_name)
if dataset.shape[0] < 6:
Expand Down

0 comments on commit 4801583

Please sign in to comment.