Commit

Merge branch 'rh/v43_model_load' into 'master'
v4.3 model load

See merge request machine-learning/dorado!606
iiSeymour committed Nov 2, 2023
2 parents 641cb08 + ee042dd commit ac2a90a
Showing 16 changed files with 1,279 additions and 116 deletions.
dorado/nn/CRFModel.cpp (75 changes: 43 additions & 32 deletions)
@@ -7,6 +7,8 @@
 #include "utils/module_utils.h"
 #include "utils/tensor_utils.h"
 
+#include <stdexcept>
+
 #if DORADO_GPU_BUILD && !defined(__APPLE__)
 #define USE_KOI 1
@@ -40,16 +42,14 @@ using namespace torch::nn;
 namespace F = torch::nn::functional;
 using Slice = torch::indexing::Slice;
 
-enum class Activation { SWISH, SWISH_CLAMP, TANH };
-
 #if USE_KOI
 
-KoiActivation get_koi_activation(Activation act) {
-    if (act == Activation::SWISH) {
+KoiActivation get_koi_activation(dorado::Activation act) {
+    if (act == dorado::Activation::SWISH) {
         return KOI_SWISH;
-    } else if (act == Activation::SWISH_CLAMP) {
+    } else if (act == dorado::Activation::SWISH_CLAMP) {
         return KOI_SWISH_CLAMP;
-    } else if (act == Activation::TANH) {
+    } else if (act == dorado::Activation::TANH) {
         return KOI_TANH;
     } else {
         throw std::logic_error("Unrecognised activation function id.");
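
Note: the file-local Activation enum deleted above now comes from a shared dorado::Activation (see the qualified uses). Its definition is not part of this diff; from the removed line it presumably reads roughly as follows — a sketch, not the actual header:

namespace dorado {
// Activation applied after a convolution. TANH bounds the output to [-1, 1],
// which is what later lets the first LSTM layer run quantised (see below).
enum class Activation { SWISH, SWISH_CLAMP, TANH };
}  // namespace dorado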
@@ -89,14 +89,16 @@ KoiActivation get_koi_activation(Activation act) {
 
 enum class LstmMode { CUBLAS_TN2C, QUANTISED_NTC, CUTLASS_TNC_I8, CUTLASS_TNC_F16 };
 
-static LstmMode get_cuda_lstm_mode(int layer_idx, int layer_size) {
+static LstmMode get_cuda_lstm_mode(int layer_idx, int layer_size, dorado::Activation activation) {
     const char *env_lstm_mode = std::getenv("DORADO_LSTM_MODE");
     if (env_lstm_mode != nullptr) {
         std::string lstm_mode_str(env_lstm_mode);
         if (lstm_mode_str == "CUBLAS_TN2C") {
             return LstmMode::CUBLAS_TN2C;
         } else if (lstm_mode_str == "CUTLASS_TNC_I8") {
-            return (layer_idx == 0) ? LstmMode::CUTLASS_TNC_F16 : LstmMode::CUTLASS_TNC_I8;
+            return (layer_idx == 0 && activation != dorado::Activation::TANH)
+                           ? LstmMode::CUTLASS_TNC_F16
+                           : LstmMode::CUTLASS_TNC_I8;
         } else if (lstm_mode_str == "CUTLASS_TNC_F16") {
             return LstmMode::CUTLASS_TNC_F16;
         }
@@ -107,7 +109,10 @@ static LstmMode get_cuda_lstm_mode(int layer_idx, int layer_size) {
     bool is_A100_H100 = ((prop->major == 8 || prop->major == 9) && prop->minor == 0);
 
     if (is_A100_H100 && layer_size <= 1024 && layer_size > 128 && (layer_size % 128) == 0) {
-        return (layer_idx == 0) ? LstmMode::CUTLASS_TNC_F16 : LstmMode::CUTLASS_TNC_I8;
+        // Zeroth LSTM can be quantised if the preceding activation is TANH
+        return (layer_idx == 0 && activation != dorado::Activation::TANH)
+                       ? LstmMode::CUTLASS_TNC_F16
+                       : LstmMode::CUTLASS_TNC_I8;
    } else if (!is_TX2 && (layer_size == 96 || layer_size == 128)) {
        return LstmMode::QUANTISED_NTC;
    }
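
Note: the net effect of the two changes above is that the first LSTM layer can now use the int8 Cutlass kernel when the preceding convolution ends in TANH, since its input is then bounded; otherwise layer 0 stays in F16 while later layers use int8. (The DORADO_LSTM_MODE environment variable still force-overrides the choice, as handled at the top of the function.) A condensed, runnable restatement of just that rule, ignoring the env override and device probing:

#include <cstdio>

enum class Activation { SWISH, SWISH_CLAMP, TANH };
enum class LstmMode { CUTLASS_TNC_I8, CUTLASS_TNC_F16 };

// Simplified sketch: assumes an A100/H100 and 128 < layer_size <= 1024 with
// layer_size % 128 == 0, i.e. the branch taken above.
static LstmMode pick_lstm_mode(int layer_idx, Activation act) {
    // Only a TANH-bounded input lets layer 0 run quantised.
    return (layer_idx == 0 && act != Activation::TANH) ? LstmMode::CUTLASS_TNC_F16
                                                       : LstmMode::CUTLASS_TNC_I8;
}

int main() {
    for (int layer = 0; layer < 5; ++layer) {
        std::printf("layer %d: SWISH -> %s, TANH -> %s\n", layer,
                    pick_lstm_mode(layer, Activation::SWISH) == LstmMode::CUTLASS_TNC_I8 ? "I8" : "F16",
                    pick_lstm_mode(layer, Activation::TANH) == LstmMode::CUTLASS_TNC_I8 ? "I8" : "F16");
    }
    return 0;
}

Under SWISH this prints F16 for layer 0 and I8 for layers 1-4; under TANH all five layers are I8.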
@@ -265,8 +270,8 @@ struct ConvolutionImpl : Module {
     int64_t chunk_size_out = chunk_size_in / stride;
     if (next_layer_is_lstm || in_size > 16) {
         // For conv2 with in_size > 16 we can use the same codepath as QUANTISED_NTC
-        LstmMode lstm_mode =
-                next_layer_is_lstm ? get_cuda_lstm_mode(0, out_size) : LstmMode::QUANTISED_NTC;
+        LstmMode lstm_mode = next_layer_is_lstm ? get_cuda_lstm_mode(0, out_size, activation)
+                                                : LstmMode::QUANTISED_NTC;
         switch (lstm_mode) {
         case LstmMode::CUTLASS_TNC_I8:
             wm.reserve({chunk_size_out, batch_size, window_size, in_size}, torch::kF16);
@@ -312,8 +317,8 @@
 
     if (next_layer_is_lstm || in_size > 16) {
         // For conv2 with in_size > 16 we can use the same codepath as QUANTISED_NTC
-        LstmMode lstm_mode =
-                next_layer_is_lstm ? get_cuda_lstm_mode(0, out_size) : LstmMode::QUANTISED_NTC;
+        LstmMode lstm_mode = next_layer_is_lstm ? get_cuda_lstm_mode(0, out_size, activation)
+                                                : LstmMode::QUANTISED_NTC;
         at::Tensor ntwc_mat, tnwc_mat;
         if (lstm_mode == LstmMode::QUANTISED_NTC) {
             ntwc_mat = wm.next({batch_size, chunk_size_out, in_size, window_size}, torch::kF16);
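
Note: both Convolution hunks follow this file's two-pass WorkingMemory pattern: reserve_working_memory() declares, via wm.reserve(), the largest buffers a layer will need, and the run path obtains output tensors via wm.next(), with wm.current_sizes tracking the shape flowing through the pipeline. The real class lives elsewhere in dorado; a minimal sketch of the contract visible here, with the arena reuse of the real implementation omitted:

#include <ATen/ATen.h>

#include <algorithm>
#include <cstdint>
#include <vector>

// Toy stand-in for dorado's WorkingMemory (names other than reserve/next/
// current_sizes are invented). The real class carves views out of one
// preallocated arena; this sketch just allocates fresh tensors.
struct WorkingMemorySketch {
    std::vector<int64_t> current_sizes;  // shape currently flowing through
    int64_t high_water_bytes = 0;        // largest request seen in pass 1

    // Pass 1 (reserve_working_memory): record the biggest allocation needed.
    void reserve(std::vector<int64_t> sizes, at::ScalarType dtype) {
        int64_t bytes = static_cast<int64_t>(c10::elementSize(dtype));
        for (int64_t s : sizes) {
            bytes *= s;
        }
        high_water_bytes = std::max(high_water_bytes, bytes);
        current_sizes = std::move(sizes);
    }

    // Pass 2 (run_koi): hand back a tensor shaped for the next layer's output.
    at::Tensor next(std::vector<int64_t> sizes, at::ScalarType dtype) {
        at::Tensor t = at::empty(sizes, at::TensorOptions().dtype(dtype));
        current_sizes = std::move(sizes);
        return t;
    }
};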
@@ -463,7 +468,7 @@ struct LinearCRFImpl : Module {
 };
 
 struct LSTMStackImpl : Module {
-    LSTMStackImpl(int size) : layer_size(size) {
+    LSTMStackImpl(int size, Activation act) : layer_size(size), activation(act) {
         // torch::nn::LSTM expects/produces [N, T, C] with batch_first == true
         rnn1 = register_module("rnn1", LSTM(LSTMOptions(size, size).batch_first(true)));
         rnn2 = register_module("rnn2", LSTM(LSTMOptions(size, size).batch_first(true)));
@@ -488,9 +493,9 @@
 #if USE_KOI
     void reserve_working_memory(WorkingMemory &wm) {
         auto in_sizes = wm.current_sizes;
-        switch (auto mode = get_cuda_lstm_mode(0, layer_size)) {
+        switch (auto mode = get_cuda_lstm_mode(0, layer_size, activation)) {
         case LstmMode::CUTLASS_TNC_F16:
-            if (get_cuda_lstm_mode(1, layer_size) == LstmMode::CUTLASS_TNC_I8) {
+            if (get_cuda_lstm_mode(1, layer_size, activation) == LstmMode::CUTLASS_TNC_I8) {
                 wm.reserve(in_sizes, torch::kI8);
             }
             // fall-through
@@ -512,7 +517,7 @@
     void run_koi(WorkingMemory &wm) {
         utils::ScopedProfileRange spr("lstm_stack", 2);
 
-        auto mode = get_cuda_lstm_mode(0, layer_size);
+        auto mode = get_cuda_lstm_mode(0, layer_size, activation);
         if (mode == LstmMode::QUANTISED_NTC) {
             return forward_quantized(wm);
         } else if (mode == LstmMode::CUBLAS_TN2C) {
@@ -609,8 +614,9 @@ struct LSTMStackImpl : Module {
         // LSTM state h(-1) in either direction.
 
         auto type_id = (in.dtype() == torch::kF16) ? KOI_F16 : KOI_I8;
-        bool convert_to_i8 = (type_id == KOI_F16) &&
-                             (get_cuda_lstm_mode(1, layer_size) == LstmMode::CUTLASS_TNC_I8);
+        bool convert_to_i8 =
+                (type_id == KOI_F16) &&
+                (get_cuda_lstm_mode(1, layer_size, activation) == LstmMode::CUTLASS_TNC_I8);
 
         int layer_idx = 0;
         for (auto &rnn : {rnn1, rnn2, rnn3, rnn4, rnn5}) {
@@ -802,6 +808,7 @@ struct LSTMStackImpl : Module {
     std::vector<at::Tensor> device_scale;
 #endif  // if USE_KOI
     int layer_size;
+    Activation activation;
     LSTM rnn1{nullptr}, rnn2{nullptr}, rnn3{nullptr}, rnn4{nullptr}, rnn5{nullptr};
 };

@@ -827,29 +834,33 @@ TORCH_MODULE(Clamp);
 
 struct CRFModelImpl : Module {
     explicit CRFModelImpl(const CRFModelConfig &config) {
-        Activation activation = config.clamp ? Activation::SWISH_CLAMP : Activation::SWISH;
-        conv1 = register_module(
-                "conv1", Convolution(config.num_features, config.conv, 5, 1, activation, false));
-        conv2 = register_module("conv2", Convolution(config.conv, 16, 5, 1, activation, false));
-        conv3 = register_module(
-                "conv3", Convolution(16, config.insize, 19, config.stride, activation, true));
+        const auto cv = config.convs;
+        const auto lstm_insize = cv[2].size;
+        conv1 = register_module("conv1", Convolution(cv[0].insize, cv[0].size, cv[0].winlen,
+                                                     cv[0].stride, cv[0].activation, false));
+        conv2 = register_module("conv2", Convolution(cv[1].insize, cv[1].size, cv[1].winlen,
+                                                     cv[1].stride, cv[1].activation, false));
+        conv3 = register_module("conv3", Convolution(cv[2].insize, lstm_insize, cv[2].winlen,
+                                                     cv[2].stride, cv[2].activation, true));
 
-        rnns = register_module("rnns", LSTMStack(config.insize));
+        rnns = register_module("rnns", LSTMStack(lstm_insize, cv[2].activation));
 
         if (config.out_features.has_value()) {
             // The linear layer is decomposed into 2 matmuls.
             const int decomposition = config.out_features.value();
-            linear1 = register_module("linear1", LinearCRF(config.insize, decomposition, true));
+            linear1 = register_module("linear1", LinearCRF(lstm_insize, decomposition, true));
             linear2 = register_module("linear2", LinearCRF(decomposition, config.outsize, false));
             clamp1 = Clamp(-5.0, 5.0, config.clamp);
             encoder = Sequential(conv1, conv2, conv3, rnns, linear1, linear2, clamp1);
-        } else if ((config.conv == 16) && (config.num_features == 1)) {
-            linear1 = register_module("linear1", LinearCRF(config.insize, config.outsize, false));
+        } else if ((config.convs[0].size > 4) && (config.num_features == 1)) {
+            // v4.x model without linear decomposition
+            linear1 = register_module("linear1", LinearCRF(lstm_insize, config.outsize, false));
             clamp1 = Clamp(-5.0, 5.0, config.clamp);
             encoder = Sequential(conv1, conv2, conv3, rnns, linear1, clamp1);
         } else {
-            linear1 = register_module("linear1",
-                                      LinearCRF(config.insize, config.outsize, true, true));
+            // Pre v4 model
+            linear1 =
+                    register_module("linear1", LinearCRF(lstm_insize, config.outsize, true, true));
             encoder = Sequential(conv1, conv2, conv3, rnns, linear1);
         }
     }
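
Note: the constructor above no longer hardcodes the three convolutions (widths 5/5/19, strides 1/1/config.stride); each layer's geometry and activation now come from config.convs, which is what enables loading the v4.3 models this MR targets. The element type of config.convs is defined in the config code, not in this diff; the field names below are read off the usage above, while the struct itself is a hypothetical sketch:

namespace dorado {

// Hypothetical shape of one entry in CRFModelConfig::convs; only the field
// names are confirmed by the constructor above.
struct ConvParams {
    int insize;             // input feature count
    int size;               // output feature count
    int winlen;             // convolution window length
    int stride;             // temporal stride
    Activation activation;  // activation applied after this conv
};

}  // namespace dorado

A legacy (pre-v4.3) config could then be expressed as, e.g.,
convs = {{num_features, conv, 5, 1, Activation::SWISH},
         {conv, 16, 5, 1, Activation::SWISH},
         {16, insize, 19, stride, Activation::SWISH}};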
@@ -869,6 +880,7 @@ struct CRFModelImpl : Module {
         conv1->reserve_working_memory(wm);
         conv2->reserve_working_memory(wm);
         conv3->reserve_working_memory(wm);
+
         rnns->reserve_working_memory(wm);
         linear1->reserve_working_memory(wm);
         if (linear2) {
@@ -887,7 +899,6 @@
         conv2->run_koi(wm);
         conv3->run_koi(wm);
         rnns->run_koi(wm);
-
         linear1->run_koi(wm);
         if (linear2) {
             linear2->run_koi(wm);