43 changes: 24 additions & 19 deletions +gpt2/model.m
@@ -2,15 +2,16 @@
% model A GPT-2 model
%
% [logits, presents] = model(X, pasts, parameters) performs prediction
% with a GPT-2 model on the input X. X is a 1-by-numInputSubwords array
% of tokenized text, and the model returns an array logits that is
% 50257-by-numInputSubwords. This array can be used to predict the next
% subword. See below for more details of inputs and outputs.
% with a GPT-2 model on the input X. X is a
% 1-by-numInputSubwords-by-numObs array of tokenized text, and the model
% returns an array logits that is 50257-by-numInputSubwords-by-numObs.
% This array can be used to predict the next subword. See below for more
% details of inputs and outputs.
%
% Inputs:
% X - A 1-by-numInputSubwords array. This array is a
% tokenized sentence. It should be created using
% the tokenizer for GPT-2.
% X - A 1-by-numInputSubwords-by-numObs array. This
% array is a tokenized sentence. It should be
% created using the tokenizer for GPT-2.
% pasts - A numLayers-by-1 cell array containing "keys"
% and "values" for the attention layers. These
% come from the previous subwords in the text we
@@ -40,14 +41,14 @@
% normalization.
%
% Outputs:
% logits - A 50257-by-numInputSubwords array of logits
% (pre-softmax outputs). If we apply softmax to
% this array, we get the probabilities for the
% next subword. However, we usually want to do
% more pre-processing before doing this (like
% taking the top-K entries). 50257 is the number
% of subwords in the vocabulary for GPT-2's
% tokenizer.
% logits - A 50257-by-numInputSubwords-by-numObs array of
% logits (pre-softmax outputs). If we apply
% softmax to this array, we get the probabilities
% for the next subword. However, we usually want
% to do more pre-processing before doing this
% (like taking the top-K entries). 50257 is the
% number of subwords in the vocabulary for
% GPT-2's tokenizer.
% presents - A numLayers-by-1 cell array containing "keys"
% and "values" from the attention blocks. We feed
% these back in as the 'pasts' input.
@@ -57,9 +58,13 @@

% Apply the embedding. If there are inputs for the "past", we need to
% offset the position embedding to account for this.
% Word embedding
seqLen = size(X, 2);
h = weights.wte_0(:, X);
h = reshape(h, size(h,1), seqLen, []);
% Positional embedding
positionOffset = size(pasts{1},2);
h = weights.wte_0( :,X ) + ...
weights.wpe_0( :, positionOffset + (1:length(X)) );
h = h + weights.wpe_0(:, positionOffset + (1:seqLen) );

% Run the layers
presents = cell(hyperparameters.NumLayers,1);
@@ -74,7 +79,7 @@
weights.ln_f_g_0, ...
weights.ln_f_b_0 );

% Calculate logits (50257-by-numInputSubwords)
logits = weights.wte_0'*h;
% Calculate logits (50257-by-numInputSubwords-by-numObs)
logits = dlmtimes(weights.wte_0', h);

end
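
The doc comment above describes the usual post-processing of the logits: softmax plus top-K filtering before picking the next subword. A rough sketch of that step follows (the variable names, the top-K value, and the sampling scheme are illustrative assumptions, not the repository's own sampling code):

[logits, presents] = gpt2.model(X, pasts, parameters);
lastLogits = extractdata(logits(:, end, 1));    % scores for the next subword, first observation
[~, idx] = sort(lastLogits, 'descend');
topK = 40;                                      % assumed cut-off
lastLogits(idx(topK+1:end)) = -Inf;             % mask everything outside the top K
p = softmax(dlarray(lastLogits), 'DataFormat', 'C');
nextSubword = find(cumsum(extractdata(p)) > rand, 1);   % sample one subword index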
47 changes: 23 additions & 24 deletions +transformer/+layer/attention.m
@@ -6,14 +6,14 @@
% 2 in [1]. See below for details of inputs and outputs.
%
% Inputs:
% X - A (numFeatures*numHeads)-by-numInputSubwords
% X - A (numFeatures*numHeads)-by-numInputSubwords-by-numObs
% input array.
% past - A numFeatures-by-numPastSubwords-by-numHeads-by-2
% past - A numFeatures-by-numPastSubwords-by-numHeads-by-numObs-by-2
% array. This contains the 'keys' and 'values' for
% past subwords. These are needed to predict future
% outputs in an autoregressive manner. 'keys' are
% stored in past(:,:,:,1) and 'values' are stored
% in past(:,:,:,2).
% stored in past(:,:,:,:,1) and 'values' are stored
% in past(:,:,:,:,2).
% weights - The weights for the full multi-head attention
% block stored in a struct. This includes:
% - attn_c_attn_w_0: A weight matrix for the
@@ -28,15 +28,15 @@
% hyper-parameter.
%
% Outputs:
% Z - A (numFeatures*numHeads)-by-numInputSubwords
% Z - A (numFeatures*numHeads)-by-numInputSubwords-by-numObs
% output array.
% present - A numFeatures-by-numAllSubwords-by-numHeads-by-2
% present - A numFeatures-by-numAllSubwords-by-numHeads-by-numObs-by-2
% array. This contains the 'keys' and 'values' that
% are created from inputs. These need to be passed
% back in as the 'past' input if we want to predict
% future outputs in an autoregressive manner. 'keys'
% are stored in present(:,:,:,1) and 'values' are
% stored in present(:,:,:,2).
% are stored in present(:,:,:,:,1) and 'values' are
% stored in present(:,:,:,:,2).
%
% References:
%
@@ -52,9 +52,9 @@

% Split the results into Q (Query), K (Keys) and V (Values).
splitSize = size(C,1)/3;
Q = C(1:splitSize,:);
K = C((splitSize+1):(2*splitSize),:);
V = C((2*splitSize+1):(3*splitSize),:);
Q = C(1:splitSize,:,:);
K = C((splitSize+1):(2*splitSize),:,:);
V = C((2*splitSize+1):(3*splitSize),:,:);

% Split heads
Q = iSplitHeads(Q, splitSize, hyperParameters.NumHeads);
@@ -63,16 +63,16 @@

% Use the past
if ~isempty(past)
PK = past(:,:,:,1);
PV = past(:,:,:,2);
PK = past(:,:,:,:,1);
PV = past(:,:,:,:,2);
K = cat(2,PK,K);
V = cat(2,PV,V);
end

% Set present. Note that this is done differently from the original
% implementation which sets the value of present before the previous if
% statement.
present = cat(4,K,V);
% statement
present = cat(5,K,V);

A = transformer.layer.multiheadAttention(Q,K,V);

@@ -81,23 +81,22 @@
A = transformer.layer.convolution1d( A, ...
weights.attn_c_proj_w_0, ...
weights.attn_c_proj_b_0 );

end

function Z = iSplitHeads(X, splitSize, numHeads)
% We permute the data to put the dimension for the heads last, so that we
% can use batched matrix multiplication to compute attention for all of the
% heads at once.
%
% X - A (numFeatures*numHeads)-by-numSubwords array.
% Z - A numFeatures-by-numSubwords-by-numHeads array.
X = reshape(X, splitSize/numHeads, numHeads, []);
Z = permute(X,[1 3 2]);
% X - A (numFeatures*numHeads)-by-numSubwords-by-numObs array.
% Z - A numFeatures-by-numSubwords-by-numHeads-by-numObs array.
X = reshape(X, splitSize/numHeads, numHeads, [], size(X,3));
Z = permute(X,[1 3 2 4]);
end

function Z = iMergeHeads(X)
% X - A numFeatures-by-numSubwords-by-numHeads array.
% Z - A (numFeatures*numHeads)-by-numSubwords array.
X = permute(X, [1 3 2]);
Z = reshape(X, size(X,1)*size(X,2), []);
% X - A numFeatures-by-numSubwords-by-numHeads-by-numObs array.
% Z - A (numFeatures*numHeads)-by-numSubwords-by-numObs array.
X = permute(X, [1 3 2 4]);
Z = reshape(X, size(X,1)*size(X,2), [], size(X,4));
end
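
Since the split/merge helpers now carry a numObs dimension, a quick shape walk-through with plain arrays may help (sizes are illustrative: numFeatures = 3, numHeads = 4, numSubwords = 5, numObs = 2, so the input has 12 rows):

X = rand(12, 5, 2);                      % (numFeatures*numHeads)-by-numSubwords-by-numObs
Xs = reshape(X, 3, 4, [], size(X,3));    % split rows into features and heads
Xs = permute(Xs, [1 3 2 4]);             % numFeatures-by-numSubwords-by-numHeads-by-numObs
Xm = permute(Xs, [1 3 2 4]);             % merge is the exact inverse...
Xm = reshape(Xm, size(Xm,1)*size(Xm,2), [], size(Xm,4));
assert(isequal(Xm, X));                  % ...so the round trip recovers X exactly

The same layout explains the cat(5,K,V) change: with observations occupying dimension 4, the 'keys'/'values' pair now stacks along dimension 5 of present.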
2 changes: 1 addition & 1 deletion +transformer/+layer/convolution1d.m
@@ -13,6 +13,6 @@
% Output:
% Z - A numOutputFeatures-by-numInputSubwords array.

Z = W*X + b;
Z = dlmtimes(W,X) + b;

end
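
This one-line change is what makes the stack batch-aware: a single weight matrix acts on every subword position of every observation. A minimal check of that behaviour, assuming dlmtimes broadcasts the ordinary matrix product across trailing dimensions of X (sizes are illustrative):

W = rand(8, 4);                          % numOutputFeatures-by-numInputFeatures
X = rand(4, 5, 3);                       % numInputFeatures-by-numInputSubwords-by-numObs
Z = dlmtimes(W, X);                      % expected size: 8-by-5-by-3
for k = 1:size(X, 3)
    % Each observation should match a plain per-observation matrix multiply.
    assert(max(abs(Z(:,:,k) - W*X(:,:,k)), [], 'all') < 1e-10);
end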
14 changes: 6 additions & 8 deletions +transformer/+layer/multiheadAttention.m
@@ -10,14 +10,12 @@
% below for details.
%
% Inputs:
% Q - A numFeatures-by-numInputSubWords-by-numHeads array of
% queries.
% K - A numFeatures-by-numAllSubWords-by-numHeads array of keys.
% V - A numFeatures-by-numAllSubWords-by-numHeads array of values.
% Q - numFeatures-by-numInputSubWords-by-numHeads-by-numObs array of queries.
% K - numFeatures-by-numAllSubWords-by-numHeads-by-numObs array of keys.
% V - numFeatures-by-numAllSubWords-by-numHeads-by-numObs array of values.
%
% Outputs:
% A - A numFeatures-by-numInputSubWords-by-numHeads array of
% attention matrices.
% A - numFeatures-by-numInputSubWords-by-numHeads-by-numObs array of attention matrices.
%
% References:
%
@@ -29,7 +27,7 @@
% matrices. W is numAllSubWords-by-numInputSubWords-by-numHeads. Each
% element of W is the dot product of a query vector from Q and a key vector
% from K.
W = dlmtimes(permute(K, [2 1 3]), Q);
W = dlmtimes(permute(K, [2 1 3 4]), Q);

% Divide by square root of d
W = W./sqrt(size(Q,1));
@@ -38,7 +36,7 @@
W = transformer.layer.maskAttentionWeights(W);

% Apply softmax
W = softmax(W, 'DataFormat', 'CTB');
W = softmax(W, 'DataFormat', 'CTUB');

% We compute the attention by taking products between the attention weights
% W and V. A is numFeatures-by-numInputSubWords-by-numHeads. One
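
In the updated weight computation, permute swaps the first two dimensions of K so that, per head and per observation, W(:,:,h,b) = K(:,:,h,b)'*Q(:,:,h,b). A sketch of that equivalence on plain arrays (illustrative sizes; dlmtimes is assumed to batch over dimensions 3 and 4):

Q = rand(4, 3, 2, 2);                    % numFeatures-by-numInputSubWords-by-numHeads-by-numObs
K = rand(4, 5, 2, 2);                    % numFeatures-by-numAllSubWords-by-numHeads-by-numObs
W = dlmtimes(permute(K, [2 1 3 4]), Q);  % numAllSubWords-by-numInputSubWords-by-numHeads-by-numObs
for b = 1:size(Q, 4)
    for h = 1:size(Q, 3)
        % Batched result should match the per-head, per-observation K'*Q.
        assert(max(abs(W(:,:,h,b) - K(:,:,h,b)'*Q(:,:,h,b)), [], 'all') < 1e-10);
    end
end

With W laid out this way, the 'CTUB' format labels the key dimension 'C', so softmax normalizes over keys while leaving the head ('U') and observation ('B') dimensions untouched.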
4 changes: 2 additions & 2 deletions +transformer/+layer/normalization.m
@@ -5,12 +5,12 @@
% Layer normalization is described in [1].
%
% Inputs:
% X - A numFeatures-by-numInputSubwords input array.
% X - A numFeatures-by-numInputSubwords-by-numObs input array.
% g - A numFeatures-by-1 weight vector.
% b - A numFeatures-by-1 bias vector.
%
% Outputs:
% Z - A numFeatures-by-numInputSubwords output array.
% Z - A numFeatures-by-numInputSubwords-by-numObs output array.
%
% References:
%
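
The normalization body itself sits outside this hunk; for reference, a minimal sketch of layer normalization over the feature dimension in the spirit of [1] (the epsilon value and its placement are assumptions, not the repository's exact implementation):

% X is numFeatures-by-numInputSubwords-by-numObs; g and b are numFeatures-by-1.
mu = mean(X, 1);                          % per-position mean over features
v  = mean((X - mu).^2, 1);                % per-position variance over features
Z  = g .* (X - mu) ./ sqrt(v + 1e-5) + b; % scale and shift (1e-5 epsilon assumed)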
2 changes: 1 addition & 1 deletion test/gpt2/layer/tblock.m
@@ -37,7 +37,7 @@ function outputHasInputSizeWithPasts(test,Input)
% Provide a fake past of sequence length 1
K_fake = dlarray(rand(C,1));
V_fake = dlarray(rand(C,1));
past = cat(4,K_fake,V_fake);
past = cat(5,K_fake,V_fake);
[y,present] = test.block(x,past,weights,hyperParameters);
test.verifySize(y,size(x));
% The size of presents is the size of past except the sequence
52 changes: 42 additions & 10 deletions test/gpt2/tmodel.m
@@ -7,25 +7,49 @@
model = @gpt2.model
end

properties(TestParameter)
InputData = iGetInputData();
end

methods(Test)
function canUseModel(test)
inputs = test.prepareInputs();
test.verifyWarningFree(@() test.model(inputs{:}));
function canUseModel(test, InputData)
> Collaborator: I'd like to see a new test that checks each batched operation matches calling the operation on each observation in turn, i.e. f([x1,x2]) = [f(x1),f(x2)] in some sense.
>
> Doing that at the model level will be most efficient, but it'd be good practice to do that at the unit level too. I can take a look at this too, but I think I'd rather see it before merging into master.
>
> Member Author: Absolutely. I was using such a test locally to confirm my changes. I'll add something to tmodel.
X = InputData;
[pasts, parameters] = test.prepareInputs();
test.verifyWarningFree(@() test.model(X, pasts, parameters));
end

function canAcceptBatches(test)
% gpt2.model should be able to accept multiple observations
% with the same sequence length

% Create inputs
[pasts, parameters] = test.prepareInputs();
numObs = 4;
seqLen = 5;
vocabSize = size( parameters.Weights.wte_0, 2 );
X = randi(vocabSize, [1 seqLen numObs]);

% Get batch results
Ybatch = test.model(X, pasts, parameters);

% Iterate over batch
YperObs = dlarray(zeros([vocabSize seqLen numObs], 'single'));
for i = 1:numObs
YperObs(:, :, i) = test.model(X(:, :, i), pasts, parameters);
end

% Verify the results are within a relative tolerance for single
% precision data
test.verifyEqual(extractdata(Ybatch), extractdata(YperObs), 'RelTol', single(1e-5));
end
end

methods(Access=private)
function inputs = prepareInputs(test)
function [pasts, parameters] = prepareInputs(test)
% Convenience method to setup inputs for
% transformer.model
X = test.prepareX();
parameters = test.prepareParameters();
pasts = test.preparePasts(parameters.Hyperparameters.NumLayers);
inputs = {X,pasts,parameters};
end

function X = prepareX(~)
X = dlarray(1);
end

function pasts = preparePasts(~,numLayers)
@@ -37,4 +61,12 @@ function canUseModel(test)
parameters = gpt2.load(parametersFile);
end
end
end

function s = iGetInputData()
s = struct( ...
'SingleToken', dlarray(1), ...
'MultiSeqLen', dlarray([1 7 2 9]), ...
'MultiSeqLenAndObs', dlarray( permute([1 7 2 9; 7 2 1 9], [3 2 1]) ) ...
);
end