Skip to content

Commit

Permalink
discojs/default_tasks: bundle base models
Browse files Browse the repository at this point in the history
  • Loading branch information
tharvik committed May 6, 2024
1 parent 0e20874 commit 3a7e932
Show file tree
Hide file tree
Showing 7 changed files with 11,427 additions and 19 deletions.
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import * as tf from '@tensorflow/tfjs'

import type { Model, Task, TaskProvider } from '../index.js'
import { data, models } from '../index.js'
import type { Model, Task, TaskProvider } from '../../index.js'
import { data, models } from '../../index.js'

import baseModel from './model.js'

export const cifar10: TaskProvider = {
getTask (): Task {
Expand Down Expand Up @@ -41,9 +43,10 @@ export const cifar10: TaskProvider = {
},

async getModel (): Promise<Model> {
const mobilenet = await tf.loadLayersModel(
'https://storage.googleapis.com/tfjs-models/tfjs/mobilenet_v1_0.25_224/model.json'
)
const mobilenet = await tf.loadLayersModel({
load: async () => Promise.resolve(baseModel),
})

const x = mobilenet.getLayer('global_average_pooling2d_1')
const predictions = tf.layers
.dense({ units: 10, activation: 'softmax', name: 'denseModified' })
Expand Down
Loading

0 comments on commit 3a7e932

Please sign in to comment.