Skip to content

Commit

Permalink
cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
elgiano committed Jul 20, 2023
1 parent afc6245 commit ea7129b
Show file tree
Hide file tree
Showing 10 changed files with 126 additions and 125 deletions.
4 changes: 2 additions & 2 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@
### v0.0.3-alpha
changed implementation to independent per-UGen model instance
- NN.load: scsynth only loads model to read info, real loading is done in UGen
- attributes interface: now in UGen, no more set and get methods
- attributes interface: now only in UGen, no more set and get methods
- added silent warmup pass option to UGen
- interface: removed blockSize as first arg, added debug, warmup and attributes args
- UGen interface: blockSize moved from first to second arg, first is inputs, added debug, warmup and attributes args

### v0.0.2-alpha
- updated backend from nn_tilde: using a looping thread
Expand Down
9 changes: 5 additions & 4 deletions plugins/NNModel/cpp/NNUGens.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,7 @@ void model_perform_load(NN* nn, bool warmup) {
Print("NNUGen: warming up model\n", path);
nn->warmupModel();
}
nn->m_loaded = true;
if (nn->m_debug >= Debug::all)
Print("NNUGen: loaded %s\n", path);
}
Expand Down Expand Up @@ -141,7 +142,7 @@ void model_perform_loop(NN *nn_instance, bool warmup) {

void NNUGen::next(int nSamples) {

if (!m_sharedData->m_modelDesc->is_loaded()) {
if (!m_sharedData->m_loaded) {
ClearUnitOutputs(this, nSamples);
return;
};
Expand Down Expand Up @@ -195,7 +196,7 @@ NN::NN(
m_bufferSize(bufferSize), m_debug(debug),
m_compute_thread(nullptr),
m_data_available_lock(0), m_result_available_lock(1),
m_should_stop_perform_thread(false)
m_should_stop_perform_thread(false), m_loaded(false)
{
m_inDim = m_method->inDim;
m_outDim = m_method->outDim;
Expand Down Expand Up @@ -223,10 +224,10 @@ NNUGen::NNUGen():
m_useThread = mWorld->mRealTime;
int modelHigherRatio = modelDesc->getHigherRatio();
if (m_bufferSize < 0) {
// NO THREAD MODE
m_useThread = false;
m_bufferSize = modelHigherRatio;
} else if (m_bufferSize == 0) {
// NO THREAD MODE
m_useThread = false;
m_bufferSize = modelHigherRatio;
} else if (m_bufferSize < modelHigherRatio) {
m_bufferSize = modelHigherRatio;
Expand Down
7 changes: 2 additions & 5 deletions plugins/NNModel/cpp/NNUGens.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,11 +27,9 @@ class NNSetAttr {
// called in audio thread: check trig, update value and flag
void update(Unit* unit, int nSamples);

const char* getName() const { return attr->name.c_str(); }
bool changed() const { return valUpdated; }
// called before model_perform
const char* getName() {
return attr->name.c_str();
}
std::string getStrValue() {
valUpdated = false;
if (attr->type == NNAttributeType::typeBool)
Expand Down Expand Up @@ -71,8 +69,7 @@ class NN {
std::vector<NNSetAttr> m_attributes;
Backend m_model;
bool m_should_stop_perform_thread;
bool m_enabled;
bool m_useGpu;
bool m_loaded;
};

class NNUGen : public SCUnit {
Expand Down
55 changes: 27 additions & 28 deletions plugins/NNModel/sc/NN.sc
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
NN {

classvar rtModelStore, rtModelsInfo;
*initClass {
rtModelStore = IdentityDictionary[];
// store model info by path
rtModelsInfo = IdentityDictionary[]
rtModelStore = IdentityDictionary[];
// store model info by path
rtModelsInfo = IdentityDictionary[]
}

*models {
Expand All @@ -20,14 +19,14 @@ NN {
if(this.isNRT, this.nrtModelStore, rtModelStore)[key] = model;
}

*prCacheInfo { |info|
*prCacheInfo { |info|
var cache = if(this.isNRT, this.nrtModelsInfo, rtModelsInfo);
var path = info.path.asSymbol;
if (cache[path].notNil) {
"NN: overriding cached info for '%'".format(path).warn;
};
cache[path] = info;
}
if (cache[path].notNil) {
"NN: overriding cached info for '%'".format(path).warn;
};
cache[path] = info;
}
*prGetCachedInfo { |path|
^if(this.isNRT, this.nrtModelsInfo, rtModelsInfo)[path.standardizePath.asSymbol]
}
Expand All @@ -44,31 +43,31 @@ NN {
};
}

*load { |key, path, id(-1), server(Server.default), action|
var model = this.model(key);
if (path.isKindOf(String).not) {
Error("NN.load: path needs to be a string, got: %").format(path).throw
};
if (model.isNil or: {model.isLoaded.not}) {
if (this.isNRT) {
var info = this.prGetCachedInfo(path) ?? {
Error("NN.load (nrt): model info not found for %".format(path)).throw;
};
model = NNModel.fromInfo(info, this.nextModelID);
*load { |key, path, id(-1), server(Server.default), action|
var model = this.model(key);
if (path.isKindOf(String).not) {
Error("NN.load: path needs to be a string, got: %").format(path).throw
};
if (model.isNil or: {model.isLoaded.not}) {
if (this.isNRT) {
var info = this.prGetCachedInfo(path) ?? {
Error("NN.load (nrt): model info not found for %".format(path)).throw;
};
model = NNModel.fromInfo(info, this.nextModelID);
this.prPut(key, model);
} {
model = NNModel.load(path, id, server, action: { |m|
this.prPut(key, m);
} {
model = NNModel.load(path, id, server, action: { |m|
this.prPut(key, m);
// call action after adding to registry: in case action needs key
action.value(m);
});
};
};
};
};
if (this.isNRT) {
server.sendMsg(*model.loadMsg);
}
^model;
}
^model;
}

*describeAll { this.models.do(_.describe) }

Expand Down
84 changes: 42 additions & 42 deletions plugins/NNModel/sc/NNModel.sc
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@ NNModel {

*new { ^nil }

minBufferSize { ^if (info.isNil) { nil } { info.minBufferSize } }
attributes { ^if(info.isNil, nil, info.attributes) }
minBufferSize { ^if (info.isNil) { nil } { info.minBufferSize } }
attributes { ^if(info.isNil) { nil } { info.attributes } }
attrIdx { |attrName|
var attrs = this.attributes ?? { ^nil };
^attrs.indexOf(attrName);
Expand All @@ -19,7 +19,7 @@ NNModel {
key { ^NN.keyForModel(this) }

method { |name|
var method;
var method;
this.methods ?? { Error("NNModel % has no methods.".format(this.key)).throw };
^this.methods.detect { |m| m.name == name };
}
Expand Down Expand Up @@ -47,17 +47,17 @@ NNModel {

model = super.newCopyArgs(server);

forkIfNeeded {
server.sync(bundles: [loadMsg]);
// server writes info file: read it
protect {
model.initFromFile(infoFile);
forkIfNeeded {
server.sync(bundles: [loadMsg]);
// server writes info file: read it
protect {
model.initFromFile(infoFile);
ServerBoot.add(model, server);
action.(model)
} {
File.delete(infoFile);
}
};
action.(model)
} {
File.delete(infoFile);
}
};

^model;
}
Expand All @@ -69,18 +69,18 @@ NNModel {
^super.newCopyArgs(server).initFromInfo(info, overrideId);
}

initFromFile { |infoFile|
var info = NNModelInfo.fromFile(infoFile);
this.initFromInfo(info);
NN.prCacheInfo(info);
}
initFromFile { |infoFile|
var info = NNModelInfo.fromFile(infoFile);
this.initFromInfo(info);
NN.prCacheInfo(info);
}

initFromInfo { |infoObj, overrideId|
info = infoObj;
path = info.path;
idx = overrideId ? info.idx;
methods = info.methods.collect { |m| m.copyForModel(this) }
}
initFromInfo { |infoObj, overrideId|
info = infoObj;
path = info.path;
idx = overrideId ? info.idx;
methods = info.methods.collect { |m| m.copyForModel(this) }
}

loadMsg { |newPath, infoFile|
^NN.loadMsg(idx, newPath ? path, infoFile)
Expand All @@ -104,29 +104,29 @@ NNModel {
}

prErrIfNoServer { |funcName|
if (server.isNil) {
Error("%: NNModel(%) is not bound to a server, can't dumpInfo. Is it a NRT model?"
.format(funcName, this.key)).throw
if (server.isNil) {
Error("%: NNModel(%) is not bound to a server, can't dumpInfo. Is it a NRT model?"
.format(funcName, this.key)).throw
};
}
}

NNModelInfo {
var <idx, <path, <minBufferSize, <methods, <attributes;
*new {}

*fromFile { |infoFile|
if (File.exists(infoFile).not) {
Error("NNModelInfo: can't load info file '%'".format(infoFile)).throw;
} {
var yaml = File.readAllString(infoFile).parseYAML[0];
^super.new.initFromDict(yaml)
}
}
*new {}

*fromFile { |infoFile|
if (File.exists(infoFile).not) {
Error("NNModelInfo: can't load info file '%'".format(infoFile)).throw;
} {
var yaml = File.readAllString(infoFile).parseYAML[0];
^super.new.initFromDict(yaml)
}
}
*fromDict { |infoDict|
^super.new.initFromDict(infoDict);
}
initFromDict { |yaml|
initFromDict { |yaml|
idx = yaml["idx"].asInteger;
path = yaml["modelPath"];
minBufferSize = yaml["minBufferSize"].asInteger;
Expand All @@ -137,7 +137,7 @@ NNModelInfo {
NNModelMethod(nil, name, n, inDim, outDim);
};
attributes = yaml["attributes"].collect(_.asSymbol) ?? { [] }
}
}

describe {
"path: %".format(this.path).postln;
Expand All @@ -154,9 +154,9 @@ NNModelMethod {

*new { |...args| ^super.newCopyArgs(*args) }

copyForModel { |model|
^this.class.newCopyArgs(model, name, idx, numInputs, numOutputs)
}
copyForModel { |model|
^this.class.newCopyArgs(model, name, idx, numInputs, numOutputs)
}

printOn { |stream|
stream << "%(%: % in, % out)".format(this.class.name, name, numInputs, numOutputs);
Expand Down
8 changes: 4 additions & 4 deletions plugins/NNModel/sc/NNUGens.sc
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
NNUGen : MultiOutUGen {

// enum UGenInputs { modelIdx=0, methodIdx, bufSize, warmup, debug, inputs };
// enum UGenInputs { modelIdx=0, methodIdx, bufSize, warmup, debug, inputs };
*ar { |modelIdx, methodIdx, bufferSize, numOutputs, warmup, debug, inputs|
^this.new1('audio', modelIdx, methodIdx, bufferSize, warmup, debug, *inputs)
.initOutputs(numOutputs, 'audio');
}

checkInputs {
// modelIdx, methodIdx and bufferSize are not modulatable
['modelIdx', 'methodIdx', 'bufferSize'].do { |name, n|
if (inputs[n].rate != \scalar) {
if (inputs[n].rate != \scalar) {
^": '%' is not modulatable. Got: %.".format(name, inputs[n]);
}
}
Expand All @@ -19,7 +19,7 @@ NNUGen : MultiOutUGen {

+ NNModelMethod {

ar { |inputs, bufferSize=0, warmup=0, debug=0, attributes(#[])|
ar { |inputs, bufferSize(-1), warmup=0, debug=0, attributes(#[])|
var attrParams;
inputs = inputs.asArray;
if (inputs.size != this.numInputs) {
Expand Down
Loading

0 comments on commit ea7129b

Please sign in to comment.